From c38627e049379908314dda0bf6b297b6c066fd9b Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Mon, 2 Dec 2024 00:26:19 +0000 Subject: [PATCH 1/2] Fix repetition penalty logits processor --- src/generation/logits_process.js | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/generation/logits_process.js b/src/generation/logits_process.js index 732af4f3f..499741105 100644 --- a/src/generation/logits_process.js +++ b/src/generation/logits_process.js @@ -406,12 +406,22 @@ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor { } /** - * A logits processor that penalises repeated output tokens. + * A logits processor that prevents the repetition of previous tokens through a penalty. + * This penalty is applied at most once per token. Note that, for decoder-only models like most LLMs, + * the considered tokens include the prompt. + * + * In the original [paper](https://arxiv.org/pdf/1909.05858.pdf), the authors suggest the use of a + * penalty of around 1.2 to achieve a good balance between truthful generation and lack of repetition. + * To penalize and reduce repetition, use `penalty` values above 1.0, where a higher value penalizes + * more strongly. To reward and encourage repetition, use `penalty` values between 0.0 and 1.0, where + * a lower value rewards more strongly. */ export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor { /** * Create a RepetitionPenaltyLogitsProcessor. - * @param {number} penalty The penalty to apply for repeated tokens. + * @param {number} penalty The parameter for repetition penalty. + * - 1.0 means no penalty. Above 1.0 penalizes previously generated tokens. + * - Between 0.0 and 1.0 rewards previously generated tokens. */ constructor(penalty) { super(); @@ -425,13 +435,9 @@ export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor { * @returns {Object} The logits with repetition penalty processing. */ _call(input_ids, logits) { - // Modify the logits corresponding to each element in `input_ids`. - // As a consequence, the logits corresponding to tokens that appear - // many times in the output will be penalised more. - for (let i = 0; i < input_ids.length; ++i) { const batch_logits_data = /** @type {Float32Array} */(logits[i].data); - for (const input_id of input_ids[i]) { + for (const input_id of new Set(input_ids[i])) { const token = Number(input_id); if (batch_logits_data[token] < 0) { batch_logits_data[token] *= this.penalty; From e48d6eb60efa9bccf445ab214d71605ca618573c Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Mon, 2 Dec 2024 00:26:58 +0000 Subject: [PATCH 2/2] Fix return types of logits processors --- src/generation/logits_process.js | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/generation/logits_process.js b/src/generation/logits_process.js index 499741105..f82634f75 100644 --- a/src/generation/logits_process.js +++ b/src/generation/logits_process.js @@ -151,7 +151,7 @@ export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor { * Apply the BOS token forcing to the logits. * @param {bigint[][]} input_ids The input IDs. * @param {Tensor} logits The logits. - * @returns {Object} The logits with BOS token forcing. + * @returns {Tensor} The logits with BOS token forcing. */ _call(input_ids, logits) { for (let i = 0; i < input_ids.length; ++i) { @@ -221,7 +221,7 @@ export class SuppressTokensAtBeginLogitsProcessor extends LogitsProcessor { * Apply the BOS token forcing to the logits. * @param {bigint[][]} input_ids The input IDs. * @param {Tensor} logits The logits. - * @returns {Object} The logits with BOS token forcing. + * @returns {Tensor} The logits with BOS token forcing. */ _call(input_ids, logits) { for (let i = 0; i < input_ids.length; ++i) { @@ -391,7 +391,7 @@ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor { * Apply the no-repeat-ngram processor to the logits. * @param {bigint[][]} input_ids The input IDs. * @param {Tensor} logits The logits. - * @returns {Object} The logits with no-repeat-ngram processing. + * @returns {Tensor} The logits with no-repeat-ngram processing. */ _call(input_ids, logits) { for (let i = 0; i < input_ids.length; ++i) { @@ -432,7 +432,7 @@ export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor { * Apply the repetition penalty to the logits. * @param {bigint[][]} input_ids The input IDs. * @param {Tensor} logits The logits. - * @returns {Object} The logits with repetition penalty processing. + * @returns {Tensor} The logits with repetition penalty processing. */ _call(input_ids, logits) { for (let i = 0; i < input_ids.length; ++i) { @@ -470,7 +470,7 @@ export class MinLengthLogitsProcessor extends LogitsProcessor { * Apply logit processor. * @param {bigint[][]} input_ids The input IDs. * @param {Tensor} logits The logits. - * @returns {Object} The processed logits. + * @returns {Tensor} The processed logits. */ _call(input_ids, logits) { for (let i = 0; i < input_ids.length; ++i) { @@ -508,7 +508,7 @@ export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor { * Apply logit processor. * @param {bigint[][]} input_ids The input IDs. * @param {Tensor} logits The logits. - * @returns {Object} The processed logits. + * @returns {Tensor} The processed logits. */ _call(input_ids, logits) { for (let i = 0; i < input_ids.length; ++i) { @@ -541,7 +541,7 @@ export class NoBadWordsLogitsProcessor extends LogitsProcessor { * Apply logit processor. * @param {bigint[][]} input_ids The input IDs. * @param {Tensor} logits The logits. - * @returns {Object} The processed logits. + * @returns {Tensor} The processed logits. */ _call(input_ids, logits) { for (let i = 0; i < input_ids.length; ++i) { @@ -602,7 +602,7 @@ export class ClassifierFreeGuidanceLogitsProcessor extends LogitsProcessor { * Apply logit processor. * @param {bigint[][]} input_ids The input IDs. * @param {Tensor} logits The logits. - * @returns {Object} The processed logits. + * @returns {Tensor} The processed logits. */ _call(input_ids, logits) { if (logits.dims[0] !== 2 * input_ids.length) { @@ -656,7 +656,7 @@ export class TemperatureLogitsWarper extends LogitsWarper { * Apply logit warper. * @param {bigint[][]} input_ids The input IDs. * @param {Tensor} logits The logits. - * @returns {Object} The processed logits. + * @returns {Tensor} The processed logits. */ _call(input_ids, logits) { const batch_logits_data = /** @type {Float32Array} */(logits.data);