From 836c0afee3f11fdc5075601824dacc7bd91308d9 Mon Sep 17 00:00:00 2001
From: Joshua Lochner
Date: Wed, 7 Aug 2024 17:25:15 +0200
Subject: [PATCH] Add `apply_chat_template` default parameters unit test

---
 tests/tokenizers.test.js | 29 +++++++++++++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/tests/tokenizers.test.js b/tests/tokenizers.test.js
index d571d3178..96e1bcd5e 100644
--- a/tests/tokenizers.test.js
+++ b/tests/tokenizers.test.js
@@ -488,6 +488,35 @@ describe('Chat templates', () => {
 
         // TODO: Add test for token_ids once bug in transformers is fixed.
     });
 
+    it('should support default parameters', async () => {
+        const tokenizer = await AutoTokenizer.from_pretrained("Xenova/Meta-Llama-3.1-Tokenizer");
+
+        // Example adapted from https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct#tool-use-with-transformers
+        const chat = [
+            { "role": "system", "content": "You are a bot that responds to weather queries." },
+            { "role": "user", "content": "Hey, what's the temperature in Paris right now?" }
+        ]
+        const tools = [
+            { 'type': 'function', 'function': { 'name': 'get_current_temperature', 'description': 'Get the current temperature at a location.', 'parameters': { 'type': 'object', 'properties': { 'location': { 'type': 'string', 'description': 'The location to get the temperature for, in the format "City, Country"' } }, 'required': ['location'] }, 'return': { 'type': 'number', 'description': 'The current temperature at the specified location in the specified units, as a float.' } } },
+        ]
+
+        { // `tools` unset (will default to `null`)
+            const text = tokenizer.apply_chat_template(chat, { tokenize: false });
+            expect(text).toEqual("<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYou are a bot that responds to weather queries.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHey, what's the temperature in Paris right now?<|eot_id|>");
+
+            const input_ids = tokenizer.apply_chat_template(chat, { tokenize: true, return_tensor: false });
+            compare(input_ids, [128000, 128006, 9125, 128007, 271, 38766, 1303, 33025, 2696, 25, 6790, 220, 2366, 18, 198, 15724, 2696, 25, 220, 1627, 10263, 220, 2366, 19, 271, 2675, 527, 264, 11164, 430, 31680, 311, 9282, 20126, 13, 128009, 128006, 882, 128007, 271, 19182, 11, 1148, 596, 279, 9499, 304, 12366, 1314, 1457, 30, 128009])
+        }
+
+        { // `tools` set
+            const text = tokenizer.apply_chat_template(chat, { tools, tokenize: false });
+            expect(text).toEqual("<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYou are a bot that responds to weather queries.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n    \"type\": \"function\",\n    \"function\": {\n        \"name\": \"get_current_temperature\",\n        \"description\": \"Get the current temperature at a location.\",\n        \"parameters\": {\n            \"type\": \"object\",\n            \"properties\": {\n                \"location\": {\n                    \"type\": \"string\",\n                    \"description\": \"The location to get the temperature for, in the format \\\"City, Country\\\"\"\n                }\n            },\n            \"required\": [\n                \"location\"\n            ]\n        },\n        \"return\": {\n            \"type\": \"number\",\n            \"description\": \"The current temperature at the specified location in the specified units, as a float.\"\n        }\n    }\n}\n\nHey, what's the temperature in Paris right now?<|eot_id|>");
+
+            const input_ids = tokenizer.apply_chat_template(chat, { tools, tokenize: true, return_tensor: false });
+            compare(input_ids, [128000, 128006, 9125, 128007, 271, 13013, 25, 6125, 27993, 198, 38766, 1303, 33025, 2696, 25, 6790, 220, 2366, 18, 198, 15724, 2696, 25, 220, 1627, 10263, 220, 2366, 19, 271, 2675, 527, 264, 11164, 430, 31680, 311, 9282, 20126, 13, 128009, 128006, 882, 128007, 271, 22818, 279, 2768, 5865, 11, 4587, 6013, 449, 264, 4823, 369, 264, 734, 1650, 449, 1202, 6300, 6105, 430, 1888, 11503, 279, 2728, 10137, 382, 66454, 304, 279, 3645, 5324, 609, 794, 734, 836, 11, 330, 14105, 794, 11240, 315, 5811, 836, 323, 1202, 907, 7966, 5519, 539, 1005, 7482, 382, 517, 262, 330, 1337, 794, 330, 1723, 761, 262, 330, 1723, 794, 341, 286, 330, 609, 794, 330, 456, 11327, 54625, 761, 286, 330, 4789, 794, 330, 1991, 279, 1510, 9499, 520, 264, 3813, 10560, 286, 330, 14105, 794, 341, 310, 330, 1337, 794, 330, 1735, 761, 310, 330, 13495, 794, 341, 394, 330, 2588, 794, 341, 504, 330, 1337, 794, 330, 928, 761, 504, 330, 4789, 794, 330, 791, 3813, 311, 636, 279, 9499, 369, 11, 304, 279, 3645, 7393, 13020, 11, 14438, 2153, 702, 394, 457, 310, 1173, 310, 330, 6413, 794, 2330, 394, 330, 2588, 702, 310, 5243, 286, 1173, 286, 330, 693, 794, 341, 310, 330, 1337, 794, 330, 4174, 761, 310, 330, 4789, 794, 330, 791, 1510, 9499, 520, 279, 5300, 3813, 304, 279, 5300, 8316, 11, 439, 264, 2273, 10246, 286, 457, 262, 457, 633, 19182, 11, 1148, 596, 279, 9499, 304, 12366, 1314, 1457, 30, 128009])
+        }
+    });
+
     // Dynamically-generated tests
     for (const [tokenizerName, tests] of Object.entries(templates)) {