From 11f6a08090caa06c503b0bcff940ae8a6d5d0dc2 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sun, 17 Sep 2023 23:57:13 +0200 Subject: [PATCH] Add support for `min_length` and `min_new_tokens` generation parameters (#308) * Add support for `MinNewTokensLengthLogitsProcessor` * Add support for `MinLengthLogitsProcessor` * Fix `generation_config` defaults * Fix `input_ids_seq_length` * Add unit tests for generation * Fix generation parameters test case * Allow specification of multiple `eos_token_ids` --- src/models.js | 92 +++++++++++++++++--------- src/utils/generation.js | 71 ++++++++++++++++++++ tests/generation.test.js | 138 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 271 insertions(+), 30 deletions(-) create mode 100644 tests/generation.test.js diff --git a/src/models.js b/src/models.js index b7f0c8214..4cc1727ae 100644 --- a/src/models.js +++ b/src/models.js @@ -64,6 +64,8 @@ import { WhisperTimeStampLogitsProcessor, NoRepeatNGramLogitsProcessor, RepetitionPenaltyLogitsProcessor, + MinLengthLogitsProcessor, + MinNewTokensLengthLogitsProcessor, Sampler, } from './utils/generation.js'; @@ -678,6 +680,7 @@ export class PreTrainedModel extends Callable { info = await Promise.all([ AutoConfig.from_pretrained(pretrained_model_name_or_path, options), constructSession(pretrained_model_name_or_path, options.model_file_name ?? 'decoder_model_merged', options), + getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options), ]); } else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) { @@ -782,17 +785,17 @@ export class PreTrainedModel extends Callable { // processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id)); // } - // if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) { - // processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)); - // } + if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) { + processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)); + } - // if (generation_config.min_new_tokens !== null && generation_config.eos_token_id !== null && generation_config.min_new_tokens > 0) { - // processors.push(new MinNewTokensLengthLogitsProcessor( - // input_ids_seq_length, - // generation_config.min_new_tokens, - // generation_config.eos_token_id - // )); - // } + if (generation_config.min_new_tokens !== null && generation_config.eos_token_id !== null && generation_config.min_new_tokens > 0) { + processors.push(new MinNewTokensLengthLogitsProcessor( + input_ids_seq_length, + generation_config.min_new_tokens, + generation_config.eos_token_id + )); + } // if (prefix_allowed_tokens_fn !== null) { // processors.push(new PrefixConstrainedLogitsProcessor( @@ -866,7 +869,8 @@ export class PreTrainedModel extends Callable { */ _get_generation_config(generation_config) { // Create empty generation config (contains defaults) - let gen_config = new GenerationConfig(); + // We pass `this.config` so that if `eos_token_id` or `bos_token_id` exist in the model's config, we will use them + let gen_config = new GenerationConfig(this.config); // Apply model's generation config, if it exists if ('generation_config' in this) { @@ -928,7 +932,7 @@ export class PreTrainedModel extends Callable { input_ids_seq_length = 0; } else { - input_ids_seq_length = inputs instanceof Tensor ? inputs.dims[0] : inputs.length; + input_ids_seq_length = inputs instanceof Tensor ? inputs.dims.at(-1) : inputs.length; // decoder-only if (input_ids_seq_length === 0) { @@ -948,6 +952,12 @@ export class PreTrainedModel extends Callable { logits_processor ) + /** @type {number[]} */ + let eos_token_ids = generation_config.eos_token_id; + if (eos_token_ids !== null && !Array.isArray(eos_token_ids)) { + eos_token_ids = [eos_token_ids]; + } + // TODO implement early_stopping // https://huggingface.co/blog/how-to-generate @@ -1007,7 +1017,7 @@ export class PreTrainedModel extends Callable { newBeam.score += logProb; - if (newTokenId === this.config.eos_token_id) { + if (eos_token_ids && eos_token_ids.includes(newTokenId)) { newBeam.done = true; } @@ -2476,10 +2486,12 @@ export class VisionEncoderDecoderModel extends PreTrainedModel { * @param {Object} config The configuration object specifying the hyperparameters and other model settings. * @param {Object} session The ONNX session containing the encoder model. * @param {any} decoder_merged_session The ONNX session containing the merged decoder model. + * @param {Object} generation_config Configuration object for the generation process. */ - constructor(config, session, decoder_merged_session) { + constructor(config, session, decoder_merged_session, generation_config) { super(config, session); this.decoder_merged_session = decoder_merged_session; + this.generation_config = generation_config; this.num_layers = this.config.decoder.n_layer; this.num_heads = this.config.decoder.n_head; @@ -2617,9 +2629,11 @@ export class GPT2PreTrainedModel extends PreTrainedModel { * Creates a new instance of the `GPT2PreTrainedModel` class. * @param {Object} config The configuration of the model. * @param {any} session The ONNX session containing the model weights. + * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session) { + constructor(config, session, generation_config) { super(config, session); + this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id this.config.pad_token_id = this.config.eos_token_id @@ -2649,9 +2663,11 @@ export class GPTNeoPreTrainedModel extends PreTrainedModel { * Creates a new instance of the `GPTNeoPreTrainedModel` class. * @param {Object} config The configuration of the model. * @param {any} session The ONNX session containing the model weights. + * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session) { + constructor(config, session, generation_config) { super(config, session); + this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id this.config.pad_token_id = this.config.eos_token_id @@ -2673,9 +2689,11 @@ export class GPTNeoXPreTrainedModel extends PreTrainedModel { * Creates a new instance of the `GPTNeoXPreTrainedModel` class. * @param {Object} config The configuration of the model. * @param {any} session The ONNX session containing the model weights. + * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session) { + constructor(config, session, generation_config) { super(config, session); + this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id this.config.pad_token_id = this.config.eos_token_id @@ -2698,9 +2716,11 @@ export class GPTJPreTrainedModel extends PreTrainedModel { * Creates a new instance of the `GPTJPreTrainedModel` class. * @param {Object} config The configuration of the model. * @param {any} session The ONNX session containing the model weights. + * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session) { + constructor(config, session, generation_config) { super(config, session); + this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id this.config.pad_token_id = this.config.eos_token_id @@ -2724,9 +2744,11 @@ export class GPTBigCodePreTrainedModel extends PreTrainedModel { * Creates a new instance of the `GPTBigCodePreTrainedModel` class. * @param {Object} config The configuration of the model. * @param {any} session The ONNX session containing the model weights. + * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session) { + constructor(config, session, generation_config) { super(config, session); + this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id this.config.pad_token_id = this.config.eos_token_id @@ -2747,11 +2769,13 @@ export class GPTBigCodeForCausalLM extends GPTBigCodePreTrainedModel { } export class CodeGenPreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `CodeGenPreTrainedModel` class. - * @param {Object} config The model configuration object. - * @param {Object} session The ONNX session object. - */ - constructor(config, session) { + * @param {Object} config The model configuration object. + * @param {Object} session The ONNX session object. + * @param {GenerationConfig} generation_config The generation configuration. + */ + constructor(config, session, generation_config) { super(config, session); + this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id this.config.pad_token_id = this.config.eos_token_id @@ -2785,11 +2809,13 @@ export class CodeGenForCausalLM extends CodeGenPreTrainedModel { } export class LlamaPreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `LlamaPreTrainedModel` class. - * @param {Object} config The model configuration object. - * @param {Object} session The ONNX session object. - */ - constructor(config, session) { + * @param {Object} config The model configuration object. + * @param {Object} session The ONNX session object. + * @param {GenerationConfig} generation_config The generation configuration. + */ + constructor(config, session, generation_config) { super(config, session); + this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id this.config.pad_token_id = this.config.eos_token_id @@ -2817,9 +2843,11 @@ export class BloomPreTrainedModel extends PreTrainedModel { * Creates a new instance of the `BloomPreTrainedModel` class. * @param {Object} config The configuration of the model. * @param {any} session The ONNX session containing the model weights. + * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session) { + constructor(config, session, generation_config) { super(config, session); + this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id this.config.pad_token_id = this.config.eos_token_id @@ -2848,9 +2876,11 @@ export class MptPreTrainedModel extends PreTrainedModel { * Creates a new instance of the `MptPreTrainedModel` class. * @param {Object} config The model configuration object. * @param {Object} session The ONNX session object. + * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session) { + constructor(config, session, generation_config) { super(config, session); + this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id this.config.pad_token_id = this.config.eos_token_id @@ -2880,9 +2910,11 @@ export class OPTPreTrainedModel extends PreTrainedModel { * Creates a new instance of the `OPTPreTrainedModel` class. * @param {Object} config The model configuration object. * @param {Object} session The ONNX session object. + * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session) { + constructor(config, session, generation_config) { super(config, session); + this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id this.config.pad_token_id = this.config.eos_token_id diff --git a/src/utils/generation.js b/src/utils/generation.js index 86806921f..588b1a562 100644 --- a/src/utils/generation.js +++ b/src/utils/generation.js @@ -420,6 +420,77 @@ export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor { } } +/** + * A logits processor that enforces a minimum number of tokens. + * + * @extends LogitsProcessor + */ +export class MinLengthLogitsProcessor extends LogitsProcessor { + /** + * Create a MinLengthLogitsProcessor. + * @param {number} min_length The minimum length below which the score of `eos_token_id` is set to negative infinity. + * @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token. + */ + constructor(min_length, eos_token_id) { + super(); + this.min_length = min_length; + this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id]; + } + + /** + * Apply logit processor. + * @param {Array} input_ids The input IDs. + * @param {Object} logits The logits. + * @returns {Object} The processed logits. + */ + _call(input_ids, logits) { + if (input_ids.length < this.min_length) { + for (const eos_token of this.eos_token_id) { + logits.data[eos_token] = -Infinity; + } + } + + return logits + } +} + +/** + * A logits processor that enforces a minimum number of new tokens. + * + * @extends LogitsProcessor + */ +export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor { + /** + * Create a MinNewTokensLengthLogitsProcessor. + * @param {number} prompt_length_to_skip The input tokens length. + * @param {number} min_new_tokens The minimum *new* tokens length below which the score of `eos_token_id` is set to negative infinity. + * @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token. + */ + constructor(prompt_length_to_skip, min_new_tokens, eos_token_id) { + super(); + this.prompt_length_to_skip = prompt_length_to_skip; + this.min_new_tokens = min_new_tokens; + this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id]; + } + + /** + * Apply logit processor. + * @param {Array} input_ids The input IDs. + * @param {Object} logits The logits. + * @returns {Object} The processed logits. + */ + _call(input_ids, logits) { + const new_tokens_length = input_ids.length - this.prompt_length_to_skip; + if (new_tokens_length < this.min_new_tokens) { + for (const eos_token of this.eos_token_id) { + logits.data[eos_token] = -Infinity; + } + } + + return logits + } +} + /** * Class that holds a configuration for a generation task. */ diff --git a/tests/generation.test.js b/tests/generation.test.js new file mode 100644 index 000000000..43e18675e --- /dev/null +++ b/tests/generation.test.js @@ -0,0 +1,138 @@ + +import { pipeline } from '../src/transformers.js'; +import { init, m, MAX_TEST_EXECUTION_TIME } from './init.js'; + +// Initialise the testing environment +init(); + +describe('Generation parameters', () => { + + // List all models which will be tested + const models = [ + 'Xenova/LaMini-Flan-T5-77M', // encoder-decoder + 'Xenova/LaMini-GPT-124M', // decoder-only + ]; + + // encoder-decoder model + it(models[0], async () => { + const text = 'how can I become more healthy?'; + + const generator = await pipeline('text2text-generation', m(models[0])); + + // default + // NOTE: Since `max_length` defaults to 20, this case also tests that. + { + const outputs = await generator(text); + + const tokens = generator.tokenizer.encode(outputs[0]) + expect(tokens.length).toEqual(20); + } + + // max_new_tokens + { + // NOTE: Without setting `min_new_tokens` (but setting `max_new_tokens`), 64 tokens are generated. + // So, the following tests are valid. + const MAX_NEW_TOKENS = 20; + const outputs = await generator(text, { + max_new_tokens: MAX_NEW_TOKENS, + }); + + const tokens = generator.tokenizer.encode(outputs[0]) + expect(tokens.length).toEqual(MAX_NEW_TOKENS + 1); // + 1 due to forced BOS token + } + + // min_length + { + // NOTE: Without setting `min_length` (but setting `max_new_tokens`), 64 tokens are generated. + // So, the following tests are valid. + const MAX_NEW_TOKENS = 128; + const MIN_LENGTH = 65; + const outputs = await generator(text, { + max_new_tokens: MAX_NEW_TOKENS, + min_length: MIN_LENGTH, + }); + + const tokens = generator.tokenizer.encode(outputs[0]) + expect(tokens.length).toBeGreaterThanOrEqual(MIN_LENGTH); + } + + // min_new_tokens + { + // NOTE: Without setting `min_new_tokens` (but setting `max_new_tokens`), 64 tokens are generated. + // So, the following tests are valid. + const MAX_NEW_TOKENS = 128; + const MIN_NEW_TOKENS = 65; + const outputs = await generator(text, { + max_new_tokens: MAX_NEW_TOKENS, + min_new_tokens: MIN_NEW_TOKENS, + }); + + const tokens = generator.tokenizer.encode(outputs[0]) + expect(tokens.length).toBeGreaterThanOrEqual(MIN_NEW_TOKENS); + } + + await generator.dispose(); + + }, MAX_TEST_EXECUTION_TIME); + + // decoder-only model + it(models[1], async () => { + const text = "### Instruction:\nTrue or False: The earth is flat?\n\n### Response: "; + + const generator = await pipeline('text-generation', m(models[1])); + + // default + // NOTE: Since `max_length` defaults to 20, this case also tests that. + { + const outputs = await generator(text); + const tokens = generator.tokenizer.encode(outputs[0].generated_text) + expect(tokens.length).toEqual(20); + } + + // max_new_tokens + { + const MAX_NEW_TOKENS = 20; + const outputs = await generator(text, { + max_new_tokens: MAX_NEW_TOKENS, + }); + const promptTokens = generator.tokenizer.encode(text) + const tokens = generator.tokenizer.encode(outputs[0].generated_text) + expect(tokens.length).toBeGreaterThan(promptTokens.length); + } + + // min_length + { + // NOTE: Without setting `min_length` (but setting `max_new_tokens`), 22 tokens are generated. + // So, the following tests are valid. + const MAX_NEW_TOKENS = 10; + const MIN_LENGTH = 25; + const outputs = await generator(text, { + max_new_tokens: MAX_NEW_TOKENS, + min_length: MIN_LENGTH, + }); + + const tokens = generator.tokenizer.encode(outputs[0].generated_text) + expect(tokens.length).toBeGreaterThanOrEqual(MIN_LENGTH); + } + + // min_new_tokens + { + // NOTE: Without setting `min_new_tokens` (but setting `max_new_tokens`), 22 tokens are generated. + // So, the following tests are valid. + const MAX_NEW_TOKENS = 32; + const MIN_NEW_TOKENS = 10; + const outputs = await generator(text, { + max_new_tokens: MAX_NEW_TOKENS, + min_new_tokens: MIN_NEW_TOKENS, + }); + + const tokens = generator.tokenizer.encode(outputs[0].generated_text) + const promptTokens = generator.tokenizer.encode(text) + expect(tokens.length).toBeGreaterThanOrEqual(promptTokens.length + MIN_NEW_TOKENS); + } + + await generator.dispose(); + + }, MAX_TEST_EXECUTION_TIME); + +}); \ No newline at end of file