Skip to content

Commit

Permalink
Add support for MinLengthLogitsProcessor
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Sep 16, 2023
1 parent f1d1263 commit 633b28b
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ import {
WhisperTimeStampLogitsProcessor,
NoRepeatNGramLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
MinLengthLogitsProcessor,
MinNewTokensLengthLogitsProcessor,

Sampler,
Expand Down Expand Up @@ -783,9 +784,9 @@ export class PreTrainedModel extends Callable {
// processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id));
// }

// if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) {
// processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id));
// }
if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) {
processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id));
}

if (generation_config.min_new_tokens !== null && generation_config.eos_token_id !== null && generation_config.min_new_tokens > 0) {
processors.push(new MinNewTokensLengthLogitsProcessor(
Expand Down
34 changes: 34 additions & 0 deletions src/utils/generation.js
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,40 @@ export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
}
}

/**
* A logits processor that enforces a minimum number of tokens.
*
* @extends LogitsProcessor
*/
export class MinLengthLogitsProcessor extends LogitsProcessor {
/**
* Create a MinLengthLogitsProcessor.
* @param {number} min_length The minimum length below which the score of `eos_token_id` is set to negative infinity.
* @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token.
*/
constructor(min_length, eos_token_id) {
super();
this.min_length = min_length;
this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
}

/**
* Apply logit processor.
* @param {Array} input_ids The input IDs.
* @param {Object} logits The logits.
* @returns {Object} The processed logits.
*/
_call(input_ids, logits) {
if (input_ids.length < this.min_length) {
for (const eos_token of this.eos_token_id) {
logits.data[eos_token] = -Infinity;
}
}

return logits
}
}

/**
* A logits processor that enforces a minimum number of new tokens.
*
Expand Down

0 comments on commit 633b28b

Please sign in to comment.