From d4f7cd50249e98bceddff07aed5da5766c328bde Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Mon, 18 Dec 2023 17:00:50 +0200 Subject: [PATCH] Add support for chat templates (#408) * Add basic support for chat templates * Cleanup * JSDoc improvements * Support conversion of user-defined functions * Cleanup * Fix function creation * Add unit tests for templates * Cleanup * Improve JSDoc * Add missing return types * Add chat templates docs to table of contents * Add support for logical negation * Fix nested logical negation * Add unit tests for logical operators * Add loop variables * Add support for `RuntimeValue` built-in functions * Add unit tests for string instance methods * Fix conversion of normal function to `FunctionValue` * Update object method unit tests * Save chat template to tokenizer_config.json during conversion * Fix `raise_exception` error * Add `!=` operator for booleans * Remember to increment loop index * Cleanup for loop evaluator * Use `is` helper function * Add support for text nodes i.e., non Jinja statements/expressions * Add auto-generated templating tests * Update unit tests * Remove unused function * Add default chat templates * Use repo with up-to-date tokenizer config * Temporarily disable zephyr test * Delete templates.test.js * Move Jinja functionality to `@huggingface/jinja` * Fix template cache type * Update chat template unit tests * Update `@huggingface/jinja` version * Fix default llama2 system prompt usage * Add unit test for llama2 w/o chat template set * Update jinja version * Update jinja version * Add unit test for user-defined chat templates Example from https://discuss.huggingface.co/t/issue-with-llama-2-chat-template-and-out-of-date-documentation/61645/3 * Add `AddedToken` for improved tokenization * Add example usage for chat templates * Add 'first' Metaspace pretokenizer prepend scheme * Formatting * Update wav2vec2 converter special tokens whitespace split * Fix Metaspace pretokenizer split criteria * Update inputs of `PreTokenizerSequence` * Improve Metaspace pretokenizer * Update llama tokenizer tests * Improve handling of legacy llama tokenizer * Re-enable SPM tests * Add static tokenizer test cases * Add llama2 static tests * Allow user to override legacy tokenizer behaviour in `.from_pretrained` * Add legacy tokenizer unit tests * Bump jinja version to 0.1.0 --- package-lock.json | 12 ++ package.json | 3 + scripts/convert.py | 7 + scripts/extra/wav2vec2.py | 4 +- src/tokenizers.js | 419 +++++++++++++++++++++++++++++++------- tests/generate_tests.py | 197 +++++++++++++++--- tests/tensor.test.js | 4 +- tests/tokenizers.test.js | 198 +++++++++++++++++- 8 files changed, 733 insertions(+), 111 deletions(-) diff --git a/package-lock.json b/package-lock.json index 624220e9b..1e3baab76 100644 --- a/package-lock.json +++ b/package-lock.json @@ -27,6 +27,9 @@ }, "optionalDependencies": { "onnxruntime-node": "1.14.0" + }, + "peerDependencies": { + "@huggingface/jinja": "^0.1.0" } }, "node_modules/@ampproject/remapping": { @@ -743,6 +746,15 @@ "node": ">=10.0.0" } }, + "node_modules/@huggingface/jinja": { + "version": "0.1.0", + "resolved": "https://registry.npmjs.org/@huggingface/jinja/-/jinja-0.1.0.tgz", + "integrity": "sha512-NgZ0imvGPHblw+nFJN2eC+so0DmvLSEieldI7gjZZbBUDE80ypG1O+DibdeWne1vQuGBYV/pC3XL//SgxiXC7g==", + "peer": true, + "engines": { + "node": ">=18" + } + }, "node_modules/@istanbuljs/load-nyc-config": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/@istanbuljs/load-nyc-config/-/load-nyc-config-1.1.0.tgz", diff --git 
a/package.json b/package.json index 3f122bcfe..5f1efd35b 100644 --- a/package.json +++ b/package.json @@ -44,6 +44,9 @@ "optionalDependencies": { "onnxruntime-node": "1.14.0" }, + "peerDependencies": { + "@huggingface/jinja": "^0.1.0" + }, "devDependencies": { "@types/jest": "^29.5.1", "catharsis": "github:xenova/catharsis", diff --git a/scripts/convert.py b/scripts/convert.py index 56a7e61a3..f8aab906a 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -283,6 +283,13 @@ def main(): # Load tokenizer tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + # To avoid inserting all chat templates into tokenizers.js, we save the chat template + # to the tokenizer_config.json file, and load it when the tokenizer is loaded. + if getattr(tokenizer, 'chat_template', None) is None and \ + getattr(tokenizer, 'use_default_system_prompt', False): + # No chat template specified, and we use the default + setattr(tokenizer, 'chat_template', tokenizer.default_chat_template) + except KeyError: pass # No Tokenizer diff --git a/scripts/extra/wav2vec2.py b/scripts/extra/wav2vec2.py index ed35b0856..d55d3d1d5 100644 --- a/scripts/extra/wav2vec2.py +++ b/scripts/extra/wav2vec2.py @@ -20,8 +20,8 @@ def generate_tokenizer_json(tokenizer): "id": v, "content": k, "single_word": False, - "lstrip": False, - "rstrip": False, + "lstrip": True, + "rstrip": True, "normalized": False, "special": True } diff --git a/src/tokenizers.js b/src/tokenizers.js index 80ccaf851..8ac89d25b 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -41,10 +41,17 @@ import { CharTrie, } from './utils/data-structures.js'; + +/** + * @typedef {Object} TokenizerProperties Additional tokenizer-specific properties. + * @property {boolean} [legacy=false] Whether or not the `legacy` behavior of the tokenizer should be used. + * @typedef {import('./utils/hub.js').PretrainedOptions & TokenizerProperties} PretrainedTokenizerOptions + */ + /** * Loads a tokenizer from the specified path. * @param {string} pretrained_model_name_or_path The path to the tokenizer directory. - * @param {import('./utils/hub.js').PretrainedOptions} options Additional options for loading the tokenizer. + * @param {PretrainedTokenizerOptions} options Additional options for loading the tokenizer. * @returns {Promise} A promise that resolves with information about the loaded tokenizer. */ async function loadTokenizer(pretrained_model_name_or_path, options) { @@ -53,6 +60,11 @@ async function loadTokenizer(pretrained_model_name_or_path, options) { getModelJSON(pretrained_model_name_or_path, 'tokenizer.json', true, options), getModelJSON(pretrained_model_name_or_path, 'tokenizer_config.json', true, options), ]) + + // Override legacy option if `options.legacy` is not null + if (options.legacy !== null) { + info[1].legacy = options.legacy; + } return info; } @@ -214,6 +226,35 @@ function whitespace_split(text) { const PUNCTUATION_REGEX = '\\p{P}\\u0021-\\u002F\\u003A-\\u0040\\u005B-\\u0060\\u007B-\\u007E'; +/** + * Represent a token added by the user on top of the existing Model vocabulary. + * AddedToken can be configured to specify the behavior they should have in various situations like: + * - Whether they should only match single words + * - Whether to include any whitespace on its left or right + */ +class AddedToken { + /** + * Creates a new instance of AddedToken. + * @param {Object} config Added token configuration object. + * @param {string} config.content The content of the added token. + * @param {number} config.id The id of the added token. 
+ * @param {boolean} [config.single_word=false] Whether this token must be a single word or can break words.
+ * @param {boolean} [config.lstrip=false] Whether this token should strip whitespaces on its left.
+ * @param {boolean} [config.rstrip=false] Whether this token should strip whitespaces on its right.
+ * @param {boolean} [config.normalized=null] Whether this token should be normalized.
+ * @param {boolean} [config.special=false] Whether this token is special.
+ */
+ constructor(config) {
+ this.content = config.content;
+ this.id = config.id;
+ this.single_word = config.single_word ?? false;
+ this.lstrip = config.lstrip ?? false;
+ this.rstrip = config.rstrip ?? false;
+ this.special = config.special ?? false;
+ this.normalized = config.normalized ?? null;
+ }
+}
+
 /**
 * Abstract base class for tokenizer models.
 *
@@ -1212,28 +1253,30 @@ class PreTokenizer extends Callable {
 }
 
 /**
- * Method that should be implemented by subclasses to define the specific pre-tokenization logic.
- *
- * @abstract
- * @param {string} text The text to pre-tokenize.
- * @returns {string[]} The pre-tokenized text.
- * @throws {Error} If the method is not implemented in the subclass.
- */
- pre_tokenize_text(text) {
+ * Method that should be implemented by subclasses to define the specific pre-tokenization logic.
+ *
+ * @abstract
+ * @param {string} text The text to pre-tokenize.
+ * @param {Object} [options] Additional options for the pre-tokenization logic.
+ * @returns {string[]} The pre-tokenized text.
+ * @throws {Error} If the method is not implemented in the subclass.
+ */
+ pre_tokenize_text(text, options) {
 throw Error("pre_tokenize_text should be implemented in subclass.")
 }
 
 /**
 * Tokenizes the given text into pre-tokens.
 * @param {string|string[]} text The text or array of texts to pre-tokenize.
+ * @param {Object} [options] Additional options for the pre-tokenization logic.
 * @returns {string[]} An array of pre-tokens.
 */
- pre_tokenize(text) {
+ pre_tokenize(text, options) {
 let result = [];
 if (Array.isArray(text)) {
- result = text.map(x => this.pre_tokenize_text(x))
+ result = text.map(x => this.pre_tokenize_text(x, options))
 } else {
- result = this.pre_tokenize_text(text);
+ result = this.pre_tokenize_text(text, options);
 }
 return result.flat();
 }
@@ -1241,10 +1284,11 @@ class PreTokenizer extends Callable {
 /**
 * Alias for {@link PreTokenizer#pre_tokenize}.
 * @param {string|string[]} text The text or array of texts to pre-tokenize.
+ * @param {Object} [options] Additional options for the pre-tokenization logic.
 * @returns {string[]} An array of pre-tokens.
 */
- _call(text) {
- return this.pre_tokenize(text);
+ _call(text, options) {
+ return this.pre_tokenize(text, options);
 }
 }
 
@@ -1269,9 +1313,10 @@ class BertPreTokenizer extends PreTokenizer {
 * Tokenizes a single text using the BERT pre-tokenization scheme.
 *
 * @param {string} text The text to tokenize.
+ * @param {Object} [options] Additional options for the pre-tokenization logic.
 * @returns {string[]} An array of tokens.
 */
- pre_tokenize_text(text) {
+ pre_tokenize_text(text, options) {
 return text.trim().match(this.pattern) || [];
 }
 }
@@ -1316,9 +1361,10 @@ class ByteLevelPreTokenizer extends PreTokenizer {
 /**
 * Tokenizes a single piece of text using byte-level tokenization.
 * @param {string} text The text to tokenize.
+ * @param {Object} [options] Additional options for the pre-tokenization logic.
 * @returns {string[]} An array of tokens. 
*/ - pre_tokenize_text(text) { + pre_tokenize_text(text, options) { // Add a leading space if the option is enabled if (this.add_prefix_space && !text.startsWith(' ')) { text = ' ' + text; @@ -1362,9 +1408,10 @@ class SplitPreTokenizer extends PreTokenizer { /** * Tokenizes text by splitting it using the given pattern. * @param {string} text The text to tokenize. + * @param {Object} [options] Additional options for the pre-tokenization logic. * @returns {string[]} An array of tokens. */ - pre_tokenize_text(text) { + pre_tokenize_text(text, options) { if (this.pattern === null) { return []; } @@ -1395,9 +1442,10 @@ class PunctuationPreTokenizer extends PreTokenizer { /** * Tokenizes text by splitting it using the given pattern. * @param {string} text The text to tokenize. + * @param {Object} [options] Additional options for the pre-tokenization logic. * @returns {string[]} An array of tokens. */ - pre_tokenize_text(text) { + pre_tokenize_text(text, options) { return text.match(this.pattern) || []; } } @@ -1424,9 +1472,10 @@ class DigitsPreTokenizer extends PreTokenizer { /** * Tokenizes text by splitting it using the given pattern. * @param {string} text The text to tokenize. + * @param {Object} [options] Additional options for the pre-tokenization logic. * @returns {string[]} An array of tokens. */ - pre_tokenize_text(text) { + pre_tokenize_text(text, options) { return text.match(this.pattern) || []; } } @@ -1604,6 +1653,7 @@ class Decoder extends Callable { super(); this.config = config; + /** @type {AddedToken[]} */ this.added_tokens = []; this.end_of_word_suffix = null; this.trim_offsets = config.trim_offsets; @@ -1875,7 +1925,7 @@ class ByteLevelDecoder extends Decoder { // continue; // } - if (this.added_tokens.includes(token)) { + if (this.added_tokens.find(x => x.content === token) !== undefined) { if (current_sub_text.length > 0) { sub_texts.push(this.convert_tokens_to_string(current_sub_text)); current_sub_text = []; @@ -1996,6 +2046,7 @@ class MetaspacePreTokenizer extends PreTokenizer { * @param {boolean} config.add_prefix_space Whether to add a prefix space to the first token. * @param {string} config.replacement The character to replace spaces with. * @param {string} [config.str_rep=config.replacement] An optional string representation of the replacement character. + * @param {'first'|'never'|'always'} [config.prepend_scheme='always'] The metaspace prepending scheme. */ constructor(config) { super(); @@ -2003,31 +2054,40 @@ class MetaspacePreTokenizer extends PreTokenizer { this.addPrefixSpace = config.add_prefix_space; this.replacement = config.replacement; this.strRep = config.str_rep || this.replacement; + this.prepend_scheme = config.prepend_scheme ?? 'always'; } /** - * This method takes a list of normalized tokens, replaces spaces with the replacement character, + * This method takes a string, replaces spaces with the replacement character, * adds a prefix space if requested, and returns a new list of tokens. - * @param {string[]|string} normalizedTokens The list of normalized tokens to pre-tokenize. + * @param {string} text The text to pre-tokenize. + * @param {Object} [options] The options for the pre-tokenization. + * @param {number} [options.section_index] The index of the section to pre-tokenize. * @returns {string[]} A new list of pre-tokenized tokens. */ - pre_tokenize(normalizedTokens) { - if (typeof normalizedTokens === 'string') { - // Metaspace acts on a list of tokens. 
If passing in a string, first split on whitespace - // NOTE: For some reason, metaspace includes trailing whitespace, so we only trim leading whitespace. - // See: https://github.com/huggingface/tokenizers/issues/1250 - normalizedTokens = normalizedTokens.trimStart().split(/\s+/); - } + pre_tokenize_text(text, { + section_index = undefined, + } = {}) { - const result = []; - for (let token of normalizedTokens) { - let normalized = token.replaceAll(' ', this.strRep); - if (this.addPrefixSpace && !normalized.startsWith(this.replacement)) { - normalized = this.strRep + normalized; - } - result.push(normalized); + let normalized = text.replaceAll(' ', this.strRep); + + if ( + // We add a prefix space if: + // (1) The addPrefixSpace option is enabled and the normalized + // token does not already start with the replacement character. + (this.addPrefixSpace && !normalized.startsWith(this.replacement)) + + // and (2) either: + // (a) prepend_scheme is 'always' + // (b) prepend_scheme is 'first' and this is the first section + && ( + this.prepend_scheme === 'always' || + (this.prepend_scheme === 'first' && section_index === 0) + ) + ) { + normalized = this.strRep + normalized; } - return result; + return [normalized]; } } @@ -2134,17 +2194,15 @@ class PreTokenizerSequence extends PreTokenizer { /** * Applies each pre-tokenizer in the sequence to the input text in turn. - * @param {string|string[]} text The text(s) to pre-tokenize. + * @param {string} text The text to pre-tokenize. + * @param {Object} [options] Additional options for the pre-tokenization logic. * @returns {string[]} The pre-tokenized text. */ - pre_tokenize_text(text) { - if (typeof text === 'string') { - text = [text]; - } + pre_tokenize_text(text, options) { // Use reduce to apply each tokenizer to the text return this.tokenizers.reduce((preTokenizedText, tokenizer) => { - return tokenizer.pre_tokenize(preTokenizedText); - }, text); + return tokenizer.pre_tokenize(preTokenizedText, options); + }, [text]); } } @@ -2163,9 +2221,10 @@ class WhitespaceSplit extends PreTokenizer { /** * Pre-tokenizes the input text by splitting it on whitespace characters. * @param {string} text The text to be pre-tokenized. + * @param {Object} [options] Additional options for the pre-tokenization logic. * @returns {string[]} An array of tokens produced by splitting the input text on whitespace. */ - pre_tokenize_text(text) { + pre_tokenize_text(text, options) { return whitespace_split(text); } } @@ -2187,9 +2246,10 @@ class ReplacePreTokenizer extends PreTokenizer { /** * Pre-tokenizes the input text by replacing certain characters. * @param {string} text The text to be pre-tokenized. + * @param {Object} [options] Additional options for the pre-tokenization logic. * @returns {string[]} An array of tokens produced by replacing certain characters. */ - pre_tokenize_text(text) { + pre_tokenize_text(text, options) { if (this.pattern === null) { return [text]; } @@ -2197,8 +2257,20 @@ class ReplacePreTokenizer extends PreTokenizer { } } +const SPECIAL_TOKEN_ATTRIBUTES = [ + 'bos_token', + 'eos_token', + 'unk_token', + 'sep_token', + 'pad_token', + 'cls_token', + 'mask_token', + // additional_special_tokens (TODO) +] export class PreTrainedTokenizer extends Callable { + _default_chat_template = `{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}`; + /** * Create a new PreTrainedTokenizer instance. 
* @param {Object} tokenizerJSON The JSON of the tokenizer.
@@ -2207,6 +2279,8 @@ export class PreTrainedTokenizer extends Callable {
 constructor(tokenizerJSON, tokenizerConfig) {
 super();
 
+ this._tokenizer_config = tokenizerConfig;
+
 // Construct parts of the tokenizer from the JSON
 this.normalizer = Normalizer.fromConfig(tokenizerJSON.normalizer);
 this.pre_tokenizer = PreTokenizer.fromConfig(tokenizerJSON.pre_tokenizer);
@@ -2220,24 +2294,25 @@
 // Add added_tokens to model
 this.special_tokens = [];
 this.all_special_ids = [];
+
+ /** @type {AddedToken[]} */
 this.added_tokens = [];
 for (let addedToken of tokenizerJSON.added_tokens) {
- let id = addedToken.id;
- let content = addedToken.content;
-
- this.added_tokens.push(content);
+ const token = new AddedToken(addedToken);
+ this.added_tokens.push(token);
 
- this.model.tokens_to_ids.set(content, id);
- this.model.vocab[id] = content;
+ this.model.tokens_to_ids.set(token.content, token.id);
+ this.model.vocab[token.id] = token.content;
 
- if (addedToken.special) {
- this.special_tokens.push(content);
- this.all_special_ids.push(id);
+ if (token.special) {
+ this.special_tokens.push(token.content);
+ this.all_special_ids.push(token.id);
 }
 }
 
 // Update additional_special_tokens
- this.special_tokens.push(...(tokenizerConfig.additional_special_tokens ?? []));
+ this.additional_special_tokens = tokenizerConfig.additional_special_tokens ?? [];
+ this.special_tokens.push(...this.additional_special_tokens);
 this.special_tokens = [...new Set(this.special_tokens)]; // Remove duplicates
 
 if (this.decoder) {
@@ -2253,17 +2328,17 @@
 this.added_tokens_regex = this.added_tokens.length > 0 ? new RegExp(
- '(' + this.added_tokens.map(escapeRegExp).join('|') + ')'
+ this.added_tokens.map(x => `${x.lstrip ? '\\s*' : ''}(${escapeRegExp(x.content)})${x.rstrip ? '\\s*' : ''}`).join('|')
 ) : null;
 
 // Set mask token if present (otherwise will be undefined, which is fine)
- this.mask_token = this.getToken(tokenizerConfig, 'mask_token');
+ this.mask_token = this.getToken('mask_token');
 this.mask_token_id = this.model.tokens_to_ids.get(this.mask_token);
 
- this.pad_token = this.getToken(tokenizerConfig, 'pad_token', 'eos_token');
+ this.pad_token = this.getToken('pad_token', 'eos_token');
 this.pad_token_id = this.model.tokens_to_ids.get(this.pad_token);
 
- this.sep_token = this.getToken(tokenizerConfig, 'sep_token');
+ this.sep_token = this.getToken('sep_token');
 this.sep_token_id = this.model.tokens_to_ids.get(this.sep_token);
 
- this.unk_token = this.getToken(tokenizerConfig, 'unk_token');
+ this.unk_token = this.getToken('unk_token');
@@ -2279,6 +2354,11 @@
 // TODO allow user to change this
 this.padding_side = 'right';
+
+ this.legacy = false;
+
+ this.chat_template = tokenizerConfig.chat_template ?? null;
+ this._compiled_template_cache = new Map();
 }
 
 /**
 * @returns {string|null} The value associated with the first matching key, or null if no match is found.
 * @throws {Error} If an object is found for a matching key and its __type property is not "AddedToken". 
*/ - getToken(tokenizerConfig, ...keys) { + getToken(...keys) { for (let key of keys) { - let item = tokenizerConfig[key]; + let item = this._tokenizer_config[key]; if (!item) continue; @@ -2310,7 +2390,7 @@ export class PreTrainedTokenizer extends Callable { * Loads a pre-trained tokenizer from the given `pretrained_model_name_or_path`. * * @param {string} pretrained_model_name_or_path The path to the pre-trained tokenizer. - * @param {import('./utils/hub.js').PretrainedOptions} options Additional options for loading the tokenizer. + * @param {PretrainedTokenizerOptions} options Additional options for loading the tokenizer. * * @throws {Error} Throws an error if the tokenizer.json or tokenizer_config.json files are not found in the `pretrained_model_name_or_path`. * @returns {Promise} A new instance of the `PreTrainedTokenizer` class. @@ -2321,6 +2401,7 @@ export class PreTrainedTokenizer extends Callable { cache_dir = null, local_files_only = false, revision = 'main', + legacy = null, } = {}) { let info = await loadTokenizer(pretrained_model_name_or_path, { @@ -2329,6 +2410,7 @@ export class PreTrainedTokenizer extends Callable { cache_dir, local_files_only, revision, + legacy, }) // @ts-ignore @@ -2525,8 +2607,10 @@ export class PreTrainedTokenizer extends Callable { // First, we take care of special tokens. Needed to avoid issues arising from // normalization and/or pretokenization (which may not preserve special tokens) const sections = this.added_tokens_regex ? text.split(this.added_tokens_regex).filter(x => x) : [text]; - let tokens = sections.map(x => { - if (this.added_tokens.includes(x)) { + + const tokens = sections.map((x, section_index) => { + const addedToken = this.added_tokens.find(t => t.content === x); + if (addedToken !== undefined) { // Ignore added tokens return x } else { @@ -2541,9 +2625,11 @@ export class PreTrainedTokenizer extends Callable { x = this.normalizer(x); } - let sectionTokens = (this.pre_tokenizer !== null) ? this.pre_tokenizer(x) : [x]; + const sectionTokens = (this.pre_tokenizer !== null) ? this.pre_tokenizer(x, { + section_index, + }) : [x]; - let tokens = this.model(sectionTokens); + const tokens = this.model(sectionTokens); return tokens; } @@ -2658,6 +2744,127 @@ export class PreTrainedTokenizer extends Callable { return decoded; } + get default_chat_template() { + if (!this._warned_about_chat_template) { + console.warn( + "No chat template is defined for this tokenizer - using a default chat template " + + "that implements the ChatML format. If the default is not appropriate for " + + "your model, please set `tokenizer.chat_template` to an appropriate template. " + + "See https://huggingface.co/docs/transformers/main/chat_templating for more information." + ) + this._warned_about_chat_template = true; // TODO move to logger.warning_once() + } + + return this._default_chat_template; + } + + /** + * @typedef {Object} Message + * @property {string} role The role of the message (e.g., "user" or "assistant" or "system"). + * @property {string} content The content of the message. + */ + + /** + * Converts a list of message objects with `"role"` and `"content"` keys to a list of token + * ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to + * determine the format and control tokens to use when converting. When chat_template is None, it will fall back + * to the default_chat_template specified at the class level. 
+ *
+ * See [here](https://huggingface.co/docs/transformers/chat_templating) for more information.
+ *
+ * **Example:** Applying a chat template to a conversation.
+ *
+ * ```javascript
+ * import { AutoTokenizer } from "@xenova/transformers";
+ *
+ * const tokenizer = await AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1");
+ *
+ * const chat = [
+ * { "role": "user", "content": "Hello, how are you?" },
+ * { "role": "assistant", "content": "I'm doing great. How can I help you today?" },
+ * { "role": "user", "content": "I'd like to show off how chat templating works!" },
+ * ]
+ *
+ * const text = await tokenizer.apply_chat_template(chat, { tokenize: false });
+ * // "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]"
+ *
+ * const input_ids = await tokenizer.apply_chat_template(chat, { tokenize: true, return_tensor: false });
+ * // [1, 733, 16289, 28793, 22557, 28725, 910, 460, 368, 28804, 733, 28748, 16289, 28793, 28737, 28742, 28719, 2548, 1598, 28723, 1602, 541, 315, 1316, 368, 3154, 28804, 2, 28705, 733, 16289, 28793, 315, 28742, 28715, 737, 298, 1347, 805, 910, 10706, 5752, 1077, 3791, 28808, 733, 28748, 16289, 28793]
+ * ```
+ *
+ * @param {Message[]} conversation A list of message objects with `"role"` and `"content"` keys.
+ * @param {Object} options An optional object containing the following properties:
+ * @param {string} [options.chat_template=null] A Jinja template to use for this conversion. If
+ * this is not passed, the model's default chat template will be used instead.
+ * @param {boolean} [options.add_generation_prompt=false] Whether to end the prompt with the token(s) that indicate
+ * the start of an assistant message. This is useful when you want to generate a response from the model.
+ * Note that this argument will be passed to the chat template, and so it must be supported in the
+ * template for this argument to have any effect.
+ * @param {boolean} [options.tokenize=true] Whether to tokenize the output. If false, the output will be a string.
+ * @param {boolean} [options.padding=false] Whether to pad sequences to the maximum length. Has no effect if tokenize is false.
+ * @param {boolean} [options.truncation=false] Whether to truncate sequences to the maximum length. Has no effect if tokenize is false.
+ * @param {number} [options.max_length=null] Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is false.
+ * If not specified, the tokenizer's `max_length` attribute will be used as a default.
+ * @param {boolean} [options.return_tensor=true] Whether to return the output as a Tensor or an Array. Has no effect if tokenize is false.
+ * @returns {Promise} A promise that resolves to the tokenized output.
+ */
+ async apply_chat_template(conversation, {
+ chat_template = null,
+ add_generation_prompt = false,
+ tokenize = true,
+ padding = false,
+ truncation = false,
+ max_length = null,
+ return_tensor = true,
+ } = {}) {
+
+ chat_template ??= this.chat_template ?? this.default_chat_template;
+
+ // Compilation function uses a cache to avoid recompiling the same template
+ let compiledTemplate = this._compiled_template_cache.get(chat_template);
+ if (compiledTemplate === undefined) {
+ // Dynamically load the `@huggingface/jinja` library. Since this is a peer dependency
+ // (i.e., must be installed separately), an error is thrown if it is not installed. 
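+ // A minimal sketch (illustrative only) of the template API this import provides,
+ // assuming `@huggingface/jinja` is installed:
+ //
+ //   const { Template } = await import('@huggingface/jinja');
+ //   const t = new Template("{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}");
+ //   t.render({ messages: [{ role: 'user', content: 'Hi' }] }); // "user: Hi\n"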
+ let Template;
+ try {
+ Template = (await import('@huggingface/jinja')).Template;
+ } catch (e) {
+ throw new Error(
+ `apply_chat_template requires '@huggingface/jinja' to be installed. ` +
+ `You can install it with \`npm install @huggingface/jinja\`.`
+ )
+ }
+ compiledTemplate = new Template(chat_template);
+ this._compiled_template_cache.set(chat_template, compiledTemplate);
+ }
+
+ const special_tokens_map = Object.create(null);
+ for (const key of SPECIAL_TOKEN_ATTRIBUTES) {
+ const value = this.getToken(key);
+ if (value) {
+ special_tokens_map[key] = value;
+ }
+ }
+
+ const rendered = compiledTemplate.render({
+ messages: conversation,
+ add_generation_prompt: add_generation_prompt,
+
+ ...special_tokens_map,
+ });
+
+ if (tokenize) {
+ return this._call(rendered, {
+ add_special_tokens: false,
+ padding,
+ truncation,
+ max_length,
+ return_tensor,
+ }).input_ids;
+ }
+
+ return rendered;
+ }
 }
 
 /**
@@ -2767,7 +2974,9 @@ export class ElectraTokenizer extends PreTrainedTokenizer { }
 export class T5Tokenizer extends PreTrainedTokenizer { }
-export class GPT2Tokenizer extends PreTrainedTokenizer { }
+export class GPT2Tokenizer extends PreTrainedTokenizer {
+ _default_chat_template = `{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}`
+}
 export class BartTokenizer extends PreTrainedTokenizer { }
 export class MBartTokenizer extends PreTrainedTokenizer {
 constructor(tokenizerJSON, tokenizerConfig) {
@@ -2793,7 +3002,8 @@ export class MBart50Tokenizer extends MBartTokenizer { } // NOTE: extends MBartT
 export class RobertaTokenizer extends PreTrainedTokenizer { }
-export class BloomTokenizer extends PreTrainedTokenizer {
+export class BloomTokenizer extends GPT2Tokenizer { // NOTE: `GPT2Tokenizer` to get the correct chat template
+
 constructor(tokenizerJSON, tokenizerConfig) {
 // Override the default (invalid) regex of the pretokenizer.
 // For more information, see https://github.com/xenova/transformers.js/issues/94
@@ -2805,8 +3015,62 @@
 super(tokenizerJSON, tokenizerConfig);
 }
 }
-export class LlamaTokenizer extends PreTrainedTokenizer { }
-export class CodeLlamaTokenizer extends PreTrainedTokenizer { }
+
+const SPIECE_UNDERLINE = "▁";
+
+export class LlamaTokenizer extends PreTrainedTokenizer {
+ _default_chat_template = `{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\n' + system_message + '\n<</SYS>>\n\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\n' + content.strip() + '\n<</SYS>>\n\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}`
+
+ DEFAULT_SYSTEM_PROMPT =
+ "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. 
Your " + + "answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure " + + "that your responses are socially unbiased and positive in nature.\n\n" + + "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not " + + "correct. If you don't know the answer to a question, please don't share false information." + + constructor(tokenizerJSON, tokenizerConfig) { + super(tokenizerJSON, tokenizerConfig); + this.use_default_system_prompt = tokenizerConfig.use_default_system_prompt ?? false; + + this.legacy = tokenizerConfig.legacy ?? true; + if (!this.legacy) { + // See https://github.com/huggingface/transformers/pull/24565 for more information + this.normalizer = null; + this.pre_tokenizer = new MetaspacePreTokenizer({ + replacement: SPIECE_UNDERLINE, + add_prefix_space: true, + prepend_scheme: "first", + }); + } + } + + /** + * Helper function to handle legacy encoding of SPM tokenizers. + * Adapted from https://github.com/huggingface/transformers/blob/e6dcf8abd6f65bb4b6dfc1831b20d9ba49ce00e2/src/transformers/models/t5/tokenization_t5.py#L374-L387 + * @param {string} text The text to encode. + * @returns {string[]} The encoded tokens. + */ + _encode_text(text) { + if (text === null) return null; + + if (this.legacy || text.length === 0) { + return super._encode_text(text); + } + + let tokens = super._encode_text(SPIECE_UNDERLINE + text.replaceAll(SPIECE_UNDERLINE, " ")); + if (tokens.length > 1 && tokens[0] === SPIECE_UNDERLINE && this.special_tokens.includes(tokens[1])) { + tokens = tokens.slice(1); + } + return tokens; + } + + get default_chat_template() { + return super.default_chat_template + .replaceAll('USE_DEFAULT_PROMPT', this.use_default_system_prompt ? 'true' : 'false') + .replaceAll('DEFAULT_SYSTEM_MESSAGE', this.DEFAULT_SYSTEM_PROMPT.replaceAll("\n", "\\n").replaceAll("'", "\\'")); + } +} +export class CodeLlamaTokenizer extends LlamaTokenizer { } // NOTE: `LlamaTokenizer` to get the correct chat template export class XLMRobertaTokenizer extends PreTrainedTokenizer { } export class MPNetTokenizer extends PreTrainedTokenizer { } @@ -3064,6 +3328,7 @@ const WHISPER_TO_LANGUAGE_CODE_MAPPING = new Map([ * @extends PreTrainedTokenizer */ export class WhisperTokenizer extends PreTrainedTokenizer { + _default_chat_template = `{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}`; /** * Decodes automatic speech recognition (ASR) sequences. @@ -3854,8 +4119,10 @@ export class MarianTokenizer extends PreTrainedTokenizer { export class Wav2Vec2CTCTokenizer extends PreTrainedTokenizer { } -export class BlenderbotTokenizer extends PreTrainedTokenizer { } -export class BlenderbotSmallTokenizer extends PreTrainedTokenizer { } +export class BlenderbotTokenizer extends PreTrainedTokenizer { + _default_chat_template = `{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}`; +} +export class BlenderbotSmallTokenizer extends BlenderbotTokenizer { } // NOTE `BlenderbotTokenizer` to get the correct chat template export class SpeechT5Tokenizer extends PreTrainedTokenizer { } @@ -3924,7 +4191,7 @@ export class AutoTokenizer { * Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a * user or organization name, like `dbmdz/bert-base-german-cased`. 
* - A path to a *directory* containing tokenizer files, e.g., `./my_model_directory/`. - * @param {import('./utils/hub.js').PretrainedOptions} options Additional options for loading the tokenizer. + * @param {PretrainedTokenizerOptions} options Additional options for loading the tokenizer. * * @returns {Promise} A new instance of the PreTrainedTokenizer class. */ @@ -3935,6 +4202,7 @@ export class AutoTokenizer { cache_dir = null, local_files_only = false, revision = 'main', + legacy = null, } = {}) { let [tokenizerJSON, tokenizerConfig] = await loadTokenizer(pretrained_model_name_or_path, { @@ -3944,10 +4212,11 @@ export class AutoTokenizer { cache_dir, local_files_only, revision, + legacy, }) // Some tokenizers are saved with the "Fast" suffix, so we remove that if present. - let tokenizerName = tokenizerConfig.tokenizer_class.replace(/Fast$/, ''); + let tokenizerName = tokenizerConfig.tokenizer_class?.replace(/Fast$/, '') ?? 'PreTrainedTokenizer'; let cls = this.TOKENIZER_CLASS_MAPPING[tokenizerName]; if (!cls) { diff --git a/tests/generate_tests.py b/tests/generate_tests.py index c65ea79a7..b7af5ed1d 100644 --- a/tests/generate_tests.py +++ b/tests/generate_tests.py @@ -3,6 +3,7 @@ import json import os +from itertools import product from transformers import AutoTokenizer, AutoConfig import numpy as np @@ -15,12 +16,18 @@ 'tiiuae/falcon-7b', ], "llama": [ - 'hf-internal-testing/llama-tokenizer', + 'hf-internal-testing/llama-tokenizer', # Special tokens: normalized=true + 'Xenova/llama2-tokenizer', # Special tokens: normalized=false + 'Xenova/llama2-chat-tokenizer', # Special tokens: normalized=false 'hf-internal-testing/llama-code-tokenizer', ], 'mpt': [ 'mosaicml/mpt-7b', ], + 't5': [ + # TODO: Add back when https://github.com/huggingface/transformers/issues/26318 is fixed + # 'Xenova/t5-tokenizer-new', + ], } MODELS_TO_IGNORE = [ @@ -37,6 +44,9 @@ TOKENIZERS_TO_IGNORE = [ # TODO: remove when https://github.com/huggingface/transformers/pull/25478 is merged 'facebook/m2m100_418M', + + # TODO: remove when https://github.com/huggingface/transformers/issues/28096 is addressed + 'RajuKandasamy/tamillama_tiny_30m', ] MAX_TESTS = { @@ -65,19 +75,14 @@ "\u0079\u006F\u0075\u2026\u00A0\u00A0\u0079\u006F\u0075\u2026\u00A0\u00A0", "▁This ▁is ▁a ▁test ▁.", "weird \uFF5E edge \uFF5E case", + + # SentencePiece-specific test cases + "\n", + " test ", + "test", ], - "custom": { - "facebook/blenderbot_small-90M": [ - # Test special tokens - "__start__hello world__end__", - # The original (python) tokenizer simply joins by spaces (regardless of special tokens or not) - "__start__ hey __end__" # --> ... --> "__start__ hey __end__" - "__start__hey __end__" # --> ... --> "__start__ hey __end__" - ], - "tiiuae/falcon-7b": [ - "12 and 123 and 1234", # Special case for splitting on 3 numbers - ], - "hf-internal-testing/llama-tokenizer": [ + "custom_by_model_type": { + "llama": [ # Additional test-cases for the Llama tokenizer, adapted from # https://github.com/belladoreai/llama-tokenizer-js/blob/master/llama-tokenizer.js#L381-L452 "grabbed", @@ -86,6 +91,7 @@ "\n", " \n", " tabs out here", + "\n\t\n", "ax\n####\nboo", "镇", "🦙", @@ -105,7 +111,19 @@ "the 20th century, in the United States and Canada.[5] In Aymara mythology, llamas are important beings. 
" \ "The Heavenly Llama is said to drink water from the ocean and urinates as it rains.[6] According to " \ "Aymara eschatology, llamas will return to the water springs and lagoons where they come from at the " \ - "end of time.[6]" + "end of time.[6]", + ] + }, + "custom": { + "facebook/blenderbot_small-90M": [ + # Test special tokens + "__start__hello world__end__", + # The original (python) tokenizer simply joins by spaces (regardless of special tokens or not) + "__start__ hey __end__" # --> ... --> "__start__ hey __end__" + "__start__hey __end__" # --> ... --> "__start__ hey __end__" + ], + "tiiuae/falcon-7b": [ + "12 and 123 and 1234", # Special case for splitting on 3 numbers ], "InstaDeepAI/nucleotide-transformer-500m-human-ref": [ # Actual protein sequences @@ -115,9 +133,73 @@ # Special tokens "", ], + + "distil-whisper/distil-small.en": [ + " <|startoftranscript|> <|en|> ", # Tests lstrip+rstrip + ], + + "Xenova/t5-tokenizer-new": [ + # Tests the new T5 tokenizer, which uses a different prepend_scheme for its pre_tokenizer: + # tokenizer._tokenizer.pre_tokenizer = Metaspace(add_prefix_space = True, replacement = "▁", prepend_scheme = "first") + # See https://github.com/huggingface/transformers/pull/26678 for more information. + # - Old (incorrect): ['▁Hey', '▁', '', '▁', '.', '▁how', '▁are', '▁you'] + # - New (correct): ['▁Hey', '▁', '', '.', '▁how', '▁are', '▁you'] + "Hey . how are you", + ], }, } +CHAT_MESSAGES_EXAMPLES = { + 'basic': [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing great. How can I help you today?"}, + {"role": "user", "content": "I'd like to show off how chat templating works!"}, + ], + + 'system': [ + {"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}, + {"role": "user", "content": "How many helicopters can a human eat in one sitting?"}, + ], + + 'system + assistant': [ + {"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}, + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing great. 
How can I help you today?"}, + {"role": "user", "content": "I'd like to show off how chat templating works!"}, + ], +} + +TOKENIZERS_WITH_CHAT_TEMPLATES = { + # https://huggingface.co/docs/transformers/main/en/chat_templating + 'Xenova/blenderbot-400M-distill': [ + 'basic', + ], + + 'mistralai/Mistral-7B-Instruct-v0.1': [ + 'basic', + ], + + 'HuggingFaceH4/zephyr-7b-beta': [ + 'system', + ], + + 'Xenova/llama-tokenizer': [ + 'basic', + 'system', + 'system + assistant', + ], + 'Xenova/llama2-tokenizer': [ + 'basic', + 'system', + 'system + assistant', + ], + 'Xenova/llama2-chat-tokenizer': [ + 'basic', + 'system', + 'system + assistant', + ], +} + FLATTENED_SUPPORTED_MODELS = [ (model_type, [ @@ -128,7 +210,7 @@ def generate_tokenizer_tests(): - results = {} + tokenization_results = {} tokenizers_to_test = FLATTENED_SUPPORTED_MODELS + \ list(ADDITIONAL_TOKENIZERS_TO_TEST.items()) @@ -139,6 +221,9 @@ def generate_tokenizer_tests(): if model_type in MAX_TESTS: tokenizer_names = tokenizer_names[:MAX_TESTS[model_type]] + custom_by_model_type_texts = TOKENIZER_TEST_DATA["custom_by_model_type"].get( + model_type, []) + print(f'Generating tests for {model_type}') for tokenizer_name in tokenizer_names: if tokenizer_name in TOKENIZERS_TO_IGNORE: @@ -148,7 +233,32 @@ def generate_tokenizer_tests(): try: # Load tokenizer - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + if model_type == 'llama': + # As of 17/12/2023, there are a few issues with the Llama tokenizers in transformers. + # (1) Encoding with fast tokenizer adds whitespace after speical tokens: + # - https://github.com/huggingface/transformers/issues/25881 + # - https://github.com/huggingface/transformers/issues/26318 + # - https://github.com/huggingface/transformers/issues/26455 + # - https://github.com/huggingface/transformers/issues/27544 + # (2) Decoding with slow tokenizer adds whitespace after special tokens: + # - https://github.com/huggingface/transformers/issues/25073 + # + # So for now, we mix and match the tokenizers: + # i.e., use the fast tokenizer for encoding, and the slow tokenizer for decoding. 
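+ # A rough sketch of the resulting split (illustrative only):
+ #   slow = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)
+ #   fast = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
+ #   ids = slow(text)['input_ids']  # encode with the slow tokenizer
+ #   decoded = fast.decode(ids)     # decode with the fast tokenizer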
+ # TODO: remove when the above issues are fixed: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, + use_fast=False, + ) + decoder_tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, + use_fast=True, + ) + + else: + decoder_tokenizer = tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name) + except (KeyError, EnvironmentError): # If a KeyError/EnvironmentError is raised from the AutoTokenizer, it # means the model does not use a tokenizer (e.g., vision models) @@ -167,7 +277,7 @@ def generate_tokenizer_tests(): tokenizer_name, []) # Run tokenizer on test cases - for text in shared_texts + custom_texts: + for text in shared_texts + custom_texts + custom_by_model_type_texts: # TODO: add with_pair option try: encoded = tokenizer(text).data @@ -175,9 +285,9 @@ def generate_tokenizer_tests(): # Ignore testing tokenizers which fail in the python library continue - decoded_with_special = tokenizer.decode( + decoded_with_special = decoder_tokenizer.decode( encoded["input_ids"], skip_special_tokens=False) - decoded_without_special = tokenizer.decode( + decoded_without_special = decoder_tokenizer.decode( encoded["input_ids"], skip_special_tokens=True) tokenizer_results.append(dict( @@ -188,9 +298,40 @@ def generate_tokenizer_tests(): )) if tokenizer_results: - results[tokenizer_name] = tokenizer_results + tokenization_results[tokenizer_name] = tokenizer_results - return results + template_results = {} + + for tokenizer_id in TOKENIZERS_WITH_CHAT_TEMPLATES: + print(f'Generating chat templates for {tokenizer_id}') + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_id, + + # TODO: Remove once https://github.com/huggingface/transformers/pull/26678 is fixed + use_fast='llama' not in tokenizer_id, + ) + tokenizer_results = [] + for key in TOKENIZERS_WITH_CHAT_TEMPLATES[tokenizer_id]: + messages = CHAT_MESSAGES_EXAMPLES[key] + + for add_generation_prompt, tokenize in product([True, False], [True, False]): + tokenizer_results.append(dict( + messages=messages, + add_generation_prompt=add_generation_prompt, + tokenize=tokenize, + target=tokenizer.apply_chat_template( + messages, + add_generation_prompt=add_generation_prompt, + tokenize=tokenize, + ), + )) + + template_results[tokenizer_id] = tokenizer_results + + return dict( + tokenization=tokenization_results, + templates=template_results, + ) def generate_config_tests(): @@ -214,10 +355,10 @@ def generate_config_tests(): return results -ARRAY_SIZES = sorted(set([2 ** i for i in range(1, 10)]) \ - | set([3 ** i for i in range(1, 8)]) \ - | set([5 ** i for i in range(1, 6)]) \ - | set([7 ** i for i in range(1, 4)])) +ARRAY_SIZES = sorted(set([2 ** i for i in range(1, 10)]) + | set([3 ** i for i in range(1, 8)]) + | set([5 ** i for i in range(1, 6)]) + | set([7 ** i for i in range(1, 4)])) def serialize_complex_array(arr): @@ -234,7 +375,8 @@ def generate_fft_tests(): for complex in [False, True]: serialize_fn = serialize_complex_array if complex else serialize_real_array for size in ARRAY_SIZES: - arr = np.random.randn(size).astype(np.complex64 if complex else np.float64) + arr = np.random.randn(size).astype( + np.complex64 if complex else np.float64) if complex: arr += np.random.randn(size) * 1j tests[f"fft_{size}_{'complex' if complex else 'real'}"] = { @@ -263,6 +405,7 @@ def main(): fft_tests = generate_fft_tests() with open(os.path.join(data_dir, "fft_tests.json"), "w", encoding="utf-8") as fp: json.dump(fft_tests, fp) - + + if __name__ == "__main__": main() diff --git a/tests/tensor.test.js b/tests/tensor.test.js 
index 0d328a9ef..93d9fd0ad 100644 --- a/tests/tensor.test.js +++ b/tests/tensor.test.js
@@ -107,7 +107,7 @@ describe('Tensor operations', () => {
 describe('mean', () => {
 it('should calculate mean', async () => {
 const t1 = new Tensor('float32', [1, 2, 3, 4, 5, 6], [2, 3, 1]);
-
+
 const target = new Tensor('float32', [3.5], []);
 const target0 = new Tensor('float32', [2.5, 3.5, 4.5], [3, 1]);
@@ -122,7 +122,7 @@
 let avg1 = mean(t1, 1);
 compare(avg1, target1, 1e-3);
-
+
 let avg2 = mean(t1, 2);
 compare(avg2, target2, 1e-3);
diff --git a/tests/tokenizers.test.js b/tests/tokenizers.test.js
index 2d4dfe683..f0147c7c1 100644 --- a/tests/tokenizers.test.js +++ b/tests/tokenizers.test.js
@@ -7,11 +7,12 @@ import { compare } from './test_utils.js';
 
 // Load test data generated by the python tests
 // TODO do this dynamically?
-let testsData = await (await getFile('./tests/data/tokenizer_tests.json')).json()
+const { tokenization, templates } = await (await getFile('./tests/data/tokenizer_tests.json')).json()
 
-describe('Tokenizers', () => {
+// Dynamic tests to ensure transformers.js (JavaScript) matches transformers (Python)
+describe('Tokenizers (dynamic)', () => {
 
- for (let [tokenizerName, tests] of Object.entries(testsData)) {
+ for (let [tokenizerName, tests] of Object.entries(tokenization)) {
 
 it(tokenizerName, async () => {
 let tokenizer = await AutoTokenizer.from_pretrained(m(tokenizerName));
@@ -24,8 +25,7 @@
 });
 
 // Add the input text to the encoded object for easier debugging
- encoded.input = test.input;
- test.encoded.input = test.input;
+ test.encoded.input = encoded.input = test.input;
 
 expect(encoded).toEqual(test.encoded);
 
@@ -40,6 +40,103 @@
 }
});

+// Tests to ensure that no matter what, the correct tokenization is returned.
+// This is necessary since there are sometimes bugs in the transformers library.
+describe('Tokenizers (hard-coded)', () => {
+ const TESTS = {
+ 'Xenova/llama-tokenizer': [ // Test legacy compatibility
+ {
+ // legacy unset => legacy=true
+ // NOTE: While incorrect, it is necessary to match legacy behaviour
+ data: {
+ "<s>\n": [1, 29871, 13],
+ },
+ legacy: null,
+ },
+ {
+ // override legacy=true (same results as above)
+ data: {
+ "<s>\n": [1, 29871, 13],
+ },
+ legacy: true,
+ },
+ {
+ // override legacy=false (fixed results)
+ data: {
+ "<s>\n": [1, 13],
+ },
+ legacy: false,
+ }
+ ],
+
+ 'Xenova/llama-tokenizer_new': [ // legacy=false
+ {
+ data: {
+ " </s> 1  2   3    4   ": [259, 2, 29871, 29896, 259, 29906, 1678, 29941, 268, 29946, 1678],
+ "<s>\n": [1, 13],
+ "</s>test</s>": [2, 1688, 2],
+ " </s> test </s> ": [259, 2, 1243, 29871, 2, 29871],
+ "A\n'll": [319, 13, 29915, 645],
+ "Hey </s>. how are you": [18637, 29871, 2, 29889, 920, 526, 366],
+ "  Hi  Hello  ": [259, 6324, 29871, 15043, 259],
+ },
+ reversible: true,
+ legacy: null,
+ },
+ { // override legacy=true (incorrect results, but necessary to match legacy behaviour)
+ data: {
+ "<s>\n": [1, 29871, 13],
+ },
+ legacy: true,
+ },
+ ],
+
+ // legacy=false
+ 'Xenova/t5-tokenizer-new': [
+ {
+ data: {
+ // https://github.com/huggingface/transformers/pull/26678
+ // ['▁Hey', '▁', '</s>', '.', '▁how', '▁are', '▁you']
+ "Hey </s>. 
how are you": [9459, 3, 1, 5, 149, 33, 25], + }, + reversible: true, + legacy: null, + }, + { + data: { + "\n": [1, 3], + "A\n'll": [71, 3, 31, 195], + }, + reversible: false, + legacy: null, + } + ], + } + + // Re-use the same tests for the llama2 tokenizer + TESTS['Xenova/llama2-tokenizer'] = TESTS['Xenova/llama-tokenizer_new']; + + for (const [tokenizerName, test_data] of Object.entries(TESTS)) { + + it(tokenizerName, async () => { + for (const { data, reversible, legacy } of test_data) { + const tokenizer = await AutoTokenizer.from_pretrained(m(tokenizerName), { legacy }); + + for (const [text, expected] of Object.entries(data)) { + const token_ids = tokenizer.encode(text, null, { add_special_tokens: false }); + expect(token_ids).toEqual(expected); + + // If reversible, test that decoding produces the original text + if (reversible) { + const decoded = tokenizer.decode(token_ids); + expect(decoded).toEqual(text); + } + } + } + }, MAX_TEST_EXECUTION_TIME); + } +}); + describe('Edge cases', () => { it('should not crash when encoding a very long string', async () => { let tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small'); @@ -81,3 +178,94 @@ describe('Extra decoding tests', () => { }, MAX_TEST_EXECUTION_TIME); }); + +describe('Chat templates', () => { + it('should generate a chat template', async () => { + const tokenizer = await AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1"); + + const chat = [ + { "role": "user", "content": "Hello, how are you?" }, + { "role": "assistant", "content": "I'm doing great. How can I help you today?" }, + { "role": "user", "content": "I'd like to show off how chat templating works!" }, + ] + + const text = await tokenizer.apply_chat_template(chat, { tokenize: false }); + + expect(text).toEqual("[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]"); + + const input_ids = await tokenizer.apply_chat_template(chat, { tokenize: true, return_tensor: false }); + compare(input_ids, [1, 733, 16289, 28793, 22557, 28725, 910, 460, 368, 28804, 733, 28748, 16289, 28793, 28737, 28742, 28719, 2548, 1598, 28723, 1602, 541, 315, 1316, 368, 3154, 28804, 2, 28705, 733, 16289, 28793, 315, 28742, 28715, 737, 298, 1347, 805, 910, 10706, 5752, 1077, 3791, 28808, 733, 28748, 16289, 28793]) + }); + + it('should support user-defined chat template', async () => { + const tokenizer = await AutoTokenizer.from_pretrained("Xenova/llama-tokenizer"); + + const chat = [ + { role: 'user', content: 'Hello, how are you?' }, + { role: 'assistant', content: "I'm doing great. How can I help you today?" }, + { role: 'user', content: "I'd like to show off how chat templating works!" 
}, + ] + + // https://discuss.huggingface.co/t/issue-with-llama-2-chat-template-and-out-of-date-documentation/61645/3 + const chat_template = ( + "{% if messages[0]['role'] == 'system' %}" + + "{% set loop_messages = messages[1:] %}" + // Extract system message if it's present + "{% set system_message = messages[0]['content'] %}" + + "{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}" + + "{% set loop_messages = messages %}" + // Or use the default system message if the flag is set + "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}" + + "{% else %}" + + "{% set loop_messages = messages %}" + + "{% set system_message = false %}" + + "{% endif %}" + + "{% if loop_messages|length == 0 and system_message %}" + // Special handling when only sys message present + "{{ bos_token + '[INST] <>\\n' + system_message + '\\n<>\\n\\n [/INST]' }}" + + "{% endif %}" + + "{% for message in loop_messages %}" + // Loop over all non-system messages + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + + "{% endif %}" + + "{% if loop.index0 == 0 and system_message != false %}" + // Embed system message in first message + "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}" + + "{% else %}" + + "{% set content = message['content'] %}" + + "{% endif %}" + + "{% if message['role'] == 'user' %}" + // After all of that, handle messages/roles in a fairly normal way + "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" + + "{% elif message['role'] == 'system' %}" + + "{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}" + + "{% elif message['role'] == 'assistant' %}" + + "{{ ' ' + content.strip() + ' ' + eos_token }}" + + "{% endif %}" + + "{% endfor %}" + ) + .replaceAll('USE_DEFAULT_PROMPT', true) + .replaceAll('DEFAULT_SYSTEM_MESSAGE', 'You are a helpful, respectful and honest assistant.'); + + const text = await tokenizer.apply_chat_template(chat, { tokenize: false, return_tensor: false, chat_template }); + + expect(text).toEqual("[INST] <>\nYou are a helpful, respectful and honest assistant.\n<>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]"); + + // TODO: Add test for token_ids once bug in transformers is fixed. + }); + + // Dynamically-generated tests + for (const [tokenizerName, tests] of Object.entries(templates)) { + + it(tokenizerName, async () => { + // NOTE: not m(...) here + // TODO: update this? + const tokenizer = await AutoTokenizer.from_pretrained(tokenizerName); + + for (let { messages, add_generation_prompt, tokenize, target } of tests) { + + const generated = await tokenizer.apply_chat_template(messages, { + tokenize, + add_generation_prompt, + return_tensor: false, + }); + expect(generated).toEqual(target) + } + }); + } +});