Skip to content

Commit

Permalink
Add support for min_length and min_new_tokens generation parameters (#308)
Browse files Browse the repository at this point in the history

* Add support for `MinNewTokensLengthLogitsProcessor`

* Add support for `MinLengthLogitsProcessor`

* Fix `generation_config` defaults

* Fix `input_ids_seq_length`

* Add unit tests for generation

* Fix generation parameters test case

* Allow specification of multiple `eos_token_ids`
  • Loading branch information
xenova authored Sep 17, 2023
1 parent ef27100 commit 11f6a08
Show file tree
Hide file tree
Showing 3 changed files with 271 additions and 30 deletions.
92 changes: 62 additions & 30 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ import {
WhisperTimeStampLogitsProcessor,
NoRepeatNGramLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
MinLengthLogitsProcessor,
MinNewTokensLengthLogitsProcessor,

Sampler,
} from './utils/generation.js';
Expand Down Expand Up @@ -678,6 +680,7 @@ export class PreTrainedModel extends Callable {
info = await Promise.all([
AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
constructSession(pretrained_model_name_or_path, options.model_file_name ?? 'decoder_model_merged', options),
getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
]);

} else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
Expand Down Expand Up @@ -782,17 +785,17 @@ export class PreTrainedModel extends Callable {
// processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id));
// }

// if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) {
// processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id));
// }
if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) {
processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id));
}

// if (generation_config.min_new_tokens !== null && generation_config.eos_token_id !== null && generation_config.min_new_tokens > 0) {
// processors.push(new MinNewTokensLengthLogitsProcessor(
// input_ids_seq_length,
// generation_config.min_new_tokens,
// generation_config.eos_token_id
// ));
// }
if (generation_config.min_new_tokens !== null && generation_config.eos_token_id !== null && generation_config.min_new_tokens > 0) {
processors.push(new MinNewTokensLengthLogitsProcessor(
input_ids_seq_length,
generation_config.min_new_tokens,
generation_config.eos_token_id
));
}

// if (prefix_allowed_tokens_fn !== null) {
// processors.push(new PrefixConstrainedLogitsProcessor(
Expand Down Expand Up @@ -866,7 +869,8 @@ export class PreTrainedModel extends Callable {
*/
_get_generation_config(generation_config) {
// Create empty generation config (contains defaults)
let gen_config = new GenerationConfig();
// We pass `this.config` so that if `eos_token_id` or `bos_token_id` exist in the model's config, we will use them
let gen_config = new GenerationConfig(this.config);

// Apply model's generation config, if it exists
if ('generation_config' in this) {
Expand Down Expand Up @@ -928,7 +932,7 @@ export class PreTrainedModel extends Callable {
input_ids_seq_length = 0;

} else {
input_ids_seq_length = inputs instanceof Tensor ? inputs.dims[0] : inputs.length;
input_ids_seq_length = inputs instanceof Tensor ? inputs.dims.at(-1) : inputs.length;

// decoder-only
if (input_ids_seq_length === 0) {
Expand All @@ -948,6 +952,12 @@ export class PreTrainedModel extends Callable {
logits_processor
)

/** @type {number[]} */
let eos_token_ids = generation_config.eos_token_id;
if (eos_token_ids !== null && !Array.isArray(eos_token_ids)) {
eos_token_ids = [eos_token_ids];
}

// TODO implement early_stopping
// https://huggingface.co/blog/how-to-generate

Expand Down Expand Up @@ -1007,7 +1017,7 @@ export class PreTrainedModel extends Callable {

newBeam.score += logProb;

if (newTokenId === this.config.eos_token_id) {
if (eos_token_ids && eos_token_ids.includes(newTokenId)) {
newBeam.done = true;
}

Expand Down Expand Up @@ -2476,10 +2486,12 @@ export class VisionEncoderDecoderModel extends PreTrainedModel {
* @param {Object} config The configuration object specifying the hyperparameters and other model settings.
* @param {Object} session The ONNX session containing the encoder model.
* @param {any} decoder_merged_session The ONNX session containing the merged decoder model.
* @param {Object} generation_config Configuration object for the generation process.
*/
constructor(config, session, decoder_merged_session) {
constructor(config, session, decoder_merged_session, generation_config) {
super(config, session);
this.decoder_merged_session = decoder_merged_session;
this.generation_config = generation_config;

this.num_layers = this.config.decoder.n_layer;
this.num_heads = this.config.decoder.n_head;
Expand Down Expand Up @@ -2617,9 +2629,11 @@ export class GPT2PreTrainedModel extends PreTrainedModel {
* Creates a new instance of the `GPT2PreTrainedModel` class.
* @param {Object} config The configuration of the model.
* @param {any} session The ONNX session containing the model weights.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, session) {
constructor(config, session, generation_config) {
super(config, session);
this.generation_config = generation_config;

// config doesn't contain pad_token_id, so we assume it is the eos_token_id
this.config.pad_token_id = this.config.eos_token_id
Expand Down Expand Up @@ -2649,9 +2663,11 @@ export class GPTNeoPreTrainedModel extends PreTrainedModel {
* Creates a new instance of the `GPTNeoPreTrainedModel` class.
* @param {Object} config The configuration of the model.
* @param {any} session The ONNX session containing the model weights.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, session) {
constructor(config, session, generation_config) {
super(config, session);
this.generation_config = generation_config;

// config doesn't contain pad_token_id, so we assume it is the eos_token_id
this.config.pad_token_id = this.config.eos_token_id
Expand All @@ -2673,9 +2689,11 @@ export class GPTNeoXPreTrainedModel extends PreTrainedModel {
* Creates a new instance of the `GPTNeoXPreTrainedModel` class.
* @param {Object} config The configuration of the model.
* @param {any} session The ONNX session containing the model weights.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, session) {
constructor(config, session, generation_config) {
super(config, session);
this.generation_config = generation_config;

// config doesn't contain pad_token_id, so we assume it is the eos_token_id
this.config.pad_token_id = this.config.eos_token_id
Expand All @@ -2698,9 +2716,11 @@ export class GPTJPreTrainedModel extends PreTrainedModel {
* Creates a new instance of the `GPTJPreTrainedModel` class.
* @param {Object} config The configuration of the model.
* @param {any} session The ONNX session containing the model weights.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, session) {
constructor(config, session, generation_config) {
super(config, session);
this.generation_config = generation_config;

// config doesn't contain pad_token_id, so we assume it is the eos_token_id
this.config.pad_token_id = this.config.eos_token_id
Expand All @@ -2724,9 +2744,11 @@ export class GPTBigCodePreTrainedModel extends PreTrainedModel {
* Creates a new instance of the `GPTBigCodePreTrainedModel` class.
* @param {Object} config The configuration of the model.
* @param {any} session The ONNX session containing the model weights.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, session) {
constructor(config, session, generation_config) {
super(config, session);
this.generation_config = generation_config;

// config doesn't contain pad_token_id, so we assume it is the eos_token_id
this.config.pad_token_id = this.config.eos_token_id
Expand All @@ -2747,11 +2769,13 @@ export class GPTBigCodeForCausalLM extends GPTBigCodePreTrainedModel { }
export class CodeGenPreTrainedModel extends PreTrainedModel {
/**
* Creates a new instance of the `CodeGenPreTrainedModel` class.
* @param {Object} config The model configuration object.
* @param {Object} session The ONNX session object.
*/
constructor(config, session) {
* @param {Object} config The model configuration object.
* @param {Object} session The ONNX session object.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, session, generation_config) {
super(config, session);
this.generation_config = generation_config;

// config doesn't contain pad_token_id, so we assume it is the eos_token_id
this.config.pad_token_id = this.config.eos_token_id
Expand Down Expand Up @@ -2785,11 +2809,13 @@ export class CodeGenForCausalLM extends CodeGenPreTrainedModel { }
export class LlamaPreTrainedModel extends PreTrainedModel {
/**
* Creates a new instance of the `LlamaPreTrainedModel` class.
* @param {Object} config The model configuration object.
* @param {Object} session The ONNX session object.
*/
constructor(config, session) {
* @param {Object} config The model configuration object.
* @param {Object} session The ONNX session object.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, session, generation_config) {
super(config, session);
this.generation_config = generation_config;

// config doesn't contain pad_token_id, so we assume it is the eos_token_id
this.config.pad_token_id = this.config.eos_token_id
Expand Down Expand Up @@ -2817,9 +2843,11 @@ export class BloomPreTrainedModel extends PreTrainedModel {
* Creates a new instance of the `BloomPreTrainedModel` class.
* @param {Object} config The configuration of the model.
* @param {any} session The ONNX session containing the model weights.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, session) {
constructor(config, session, generation_config) {
super(config, session);
this.generation_config = generation_config;

// config doesn't contain pad_token_id, so we assume it is the eos_token_id
this.config.pad_token_id = this.config.eos_token_id
Expand Down Expand Up @@ -2848,9 +2876,11 @@ export class MptPreTrainedModel extends PreTrainedModel {
* Creates a new instance of the `MptPreTrainedModel` class.
* @param {Object} config The model configuration object.
* @param {Object} session The ONNX session object.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, session) {
constructor(config, session, generation_config) {
super(config, session);
this.generation_config = generation_config;

// config doesn't contain pad_token_id, so we assume it is the eos_token_id
this.config.pad_token_id = this.config.eos_token_id
Expand Down Expand Up @@ -2880,9 +2910,11 @@ export class OPTPreTrainedModel extends PreTrainedModel {
* Creates a new instance of the `OPTPreTrainedModel` class.
* @param {Object} config The model configuration object.
* @param {Object} session The ONNX session object.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, session) {
constructor(config, session, generation_config) {
super(config, session);
this.generation_config = generation_config;

// config doesn't contain pad_token_id, so we assume it is the eos_token_id
this.config.pad_token_id = this.config.eos_token_id
Expand Down
71 changes: 71 additions & 0 deletions src/utils/generation.js
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,77 @@ export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
}
}

/**
 * Logits processor that rules out every end-of-sequence token while the
 * sequence is still shorter than a required minimum length.
 *
 * @extends LogitsProcessor
 */
export class MinLengthLogitsProcessor extends LogitsProcessor {
    /**
     * Build a MinLengthLogitsProcessor.
     * @param {number} min_length Length threshold; while `input_ids` is shorter than this, the EOS scores are forced to negative infinity.
     * @param {number|number[]} eos_token_id A single end-of-sequence token ID, or a list of them.
     */
    constructor(min_length, eos_token_id) {
        super();
        this.min_length = min_length;

        // Always store the EOS IDs as a list so `_call` has a single code path.
        // An array argument is kept as-is (same reference); a scalar is wrapped.
        let eosList = eos_token_id;
        if (!Array.isArray(eosList)) {
            eosList = [eosList];
        }
        this.eos_token_id = eosList;
    }

    /**
     * Mask the EOS tokens when the sequence has not yet reached `min_length`.
     * @param {Array} input_ids The token IDs of the sequence so far; only its `length` is read.
     * @param {Object} logits The logits; the entries of `logits.data` at the EOS IDs are overwritten in place.
     * @returns {Object} The same `logits` object that was passed in.
     */
    _call(input_ids, logits) {
        const tooShort = input_ids.length < this.min_length;
        if (!tooShort) {
            // Minimum already reached: leave the scores untouched.
            return logits;
        }

        const blockedIds = this.eos_token_id;
        for (let i = 0; i < blockedIds.length; ++i) {
            logits.data[blockedIds[i]] = Number.NEGATIVE_INFINITY;
        }
        return logits;
    }
}

/**
 * Logits processor that rules out every end-of-sequence token until a
 * minimum number of tokens beyond the prompt has been produced.
 *
 * @extends LogitsProcessor
 */
export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor {
    /**
     * Build a MinNewTokensLengthLogitsProcessor.
     * @param {number} prompt_length_to_skip Number of input (prompt) tokens, subtracted from the sequence length before comparing.
     * @param {number} min_new_tokens Threshold for *new* tokens; while fewer than this exist, the EOS scores are forced to negative infinity.
     * @param {number|number[]} eos_token_id A single end-of-sequence token ID, or a list of them.
     */
    constructor(prompt_length_to_skip, min_new_tokens, eos_token_id) {
        super();
        this.prompt_length_to_skip = prompt_length_to_skip;
        this.min_new_tokens = min_new_tokens;

        // Always store the EOS IDs as a list so `_call` has a single code path.
        // An array argument is kept as-is (same reference); a scalar is wrapped.
        let eosList = eos_token_id;
        if (!Array.isArray(eosList)) {
            eosList = [eosList];
        }
        this.eos_token_id = eosList;
    }

    /**
     * Mask the EOS tokens when too few new tokens have been generated.
     * @param {Array} input_ids The token IDs of the sequence so far; only its `length` is read.
     * @param {Object} logits The logits; the entries of `logits.data` at the EOS IDs are overwritten in place.
     * @returns {Object} The same `logits` object that was passed in.
     */
    _call(input_ids, logits) {
        // Tokens beyond the prompt = total length minus the prompt length.
        const generatedCount = input_ids.length - this.prompt_length_to_skip;
        const needsMore = generatedCount < this.min_new_tokens;
        if (!needsMore) {
            // Enough new tokens already: leave the scores untouched.
            return logits;
        }

        const blockedIds = this.eos_token_id;
        for (let i = 0; i < blockedIds.length; ++i) {
            logits.data[blockedIds[i]] = Number.NEGATIVE_INFINITY;
        }
        return logits;
    }
}

/**
* Class that holds a configuration for a generation task.
*/
Expand Down
Loading

0 comments on commit 11f6a08

Please sign in to comment.