Fix repetition penalty logits processor #1062

Merged · 2 commits · Dec 2, 2024
38 changes: 22 additions & 16 deletions src/generation/logits_process.js
@@ -151,7 +151,7 @@ export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
      * Apply the BOS token forcing to the logits.
      * @param {bigint[][]} input_ids The input IDs.
      * @param {Tensor} logits The logits.
-     * @returns {Object} The logits with BOS token forcing.
+     * @returns {Tensor} The logits with BOS token forcing.
      */
     _call(input_ids, logits) {
         for (let i = 0; i < input_ids.length; ++i) {
@@ -221,7 +221,7 @@ export class SuppressTokensAtBeginLogitsProcessor extends LogitsProcessor {
      * Apply the BOS token forcing to the logits.
      * @param {bigint[][]} input_ids The input IDs.
      * @param {Tensor} logits The logits.
-     * @returns {Object} The logits with BOS token forcing.
+     * @returns {Tensor} The logits with BOS token forcing.
      */
     _call(input_ids, logits) {
         for (let i = 0; i < input_ids.length; ++i) {
@@ -391,7 +391,7 @@ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
      * Apply the no-repeat-ngram processor to the logits.
      * @param {bigint[][]} input_ids The input IDs.
      * @param {Tensor} logits The logits.
-     * @returns {Object} The logits with no-repeat-ngram processing.
+     * @returns {Tensor} The logits with no-repeat-ngram processing.
      */
     _call(input_ids, logits) {
         for (let i = 0; i < input_ids.length; ++i) {
@@ -406,12 +406,22 @@ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
 }
 
 /**
- * A logits processor that penalises repeated output tokens.
+ * A logits processor that prevents the repetition of previous tokens through a penalty.
+ * This penalty is applied at most once per token. Note that, for decoder-only models like most LLMs,
+ * the considered tokens include the prompt.
+ *
+ * In the original [paper](https://arxiv.org/pdf/1909.05858.pdf), the authors suggest the use of a
+ * penalty of around 1.2 to achieve a good balance between truthful generation and lack of repetition.
+ * To penalize and reduce repetition, use `penalty` values above 1.0, where a higher value penalizes
+ * more strongly. To reward and encourage repetition, use `penalty` values between 0.0 and 1.0, where
+ * a lower value rewards more strongly.
  */
 export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
     /**
      * Create a RepetitionPenaltyLogitsProcessor.
-     * @param {number} penalty The penalty to apply for repeated tokens.
+     * @param {number} penalty The parameter for repetition penalty.
+     * - 1.0 means no penalty. Above 1.0 penalizes previously generated tokens.
+     * - Between 0.0 and 1.0 rewards previously generated tokens.
      */
     constructor(penalty) {
         super();
@@ -422,16 +432,12 @@ export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
      * Apply the repetition penalty to the logits.
      * @param {bigint[][]} input_ids The input IDs.
      * @param {Tensor} logits The logits.
-     * @returns {Object} The logits with repetition penalty processing.
+     * @returns {Tensor} The logits with repetition penalty processing.
      */
     _call(input_ids, logits) {
-        // Modify the logits corresponding to each element in `input_ids`.
-        // As a consequence, the logits corresponding to tokens that appear
-        // many times in the output will be penalised more.
-
         for (let i = 0; i < input_ids.length; ++i) {
             const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
-            for (const input_id of input_ids[i]) {
+            for (const input_id of new Set(input_ids[i])) {
                 const token = Number(input_id);
                 if (batch_logits_data[token] < 0) {
                     batch_logits_data[token] *= this.penalty;
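The effect of this change can be sketched in isolation. The snippet below is a minimal standalone version of the fixed loop, using a plain array in place of the library's Tensor data and an illustrative function name; the `else` branch (dividing positive logits) is truncated from the hunk above but follows the same CTRL-style scheme, so that a penalty above 1.0 always makes a seen token less likely. The key point of the fix is that iterating over `new Set(ids)` applies the penalty at most once per token, instead of compounding it for every repetition in the input:

```javascript
// Minimal sketch of the fixed repetition penalty (illustrative name, plain
// array instead of the library's Tensor data).
function applyRepetitionPenalty(logits, ids, penalty) {
  // `new Set(ids)` deduplicates the token IDs, so each distinct token is
  // penalised at most once, no matter how often it occurs in `ids`.
  for (const id of new Set(ids)) {
    const token = Number(id); // token IDs arrive as BigInt
    if (logits[token] < 0) {
      logits[token] *= penalty; // negative logit: scale further down
    } else {
      logits[token] /= penalty; // positive logit: shrink towards zero
    }
  }
  return logits;
}

// A token repeated three times is penalised once, not three times:
const logits = [2.0, -1.0, 0.5];
applyRepetitionPenalty(logits, [1n, 1n, 1n], 1.2);
console.log(logits[1]); // -1.2, rather than -1.0 * 1.2 ** 3
```

With the pre-fix loop (`for (const input_id of input_ids[i])`), the same call would have multiplied the logit by the penalty three times, which is exactly the over-penalisation this PR removes.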
@@ -464,7 +470,7 @@ export class MinLengthLogitsProcessor extends LogitsProcessor {
      * Apply logit processor.
      * @param {bigint[][]} input_ids The input IDs.
      * @param {Tensor} logits The logits.
-     * @returns {Object} The processed logits.
+     * @returns {Tensor} The processed logits.
      */
     _call(input_ids, logits) {
         for (let i = 0; i < input_ids.length; ++i) {
@@ -502,7 +508,7 @@ export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor {
      * Apply logit processor.
      * @param {bigint[][]} input_ids The input IDs.
      * @param {Tensor} logits The logits.
-     * @returns {Object} The processed logits.
+     * @returns {Tensor} The processed logits.
      */
     _call(input_ids, logits) {
         for (let i = 0; i < input_ids.length; ++i) {
@@ -535,7 +541,7 @@ export class NoBadWordsLogitsProcessor extends LogitsProcessor {
      * Apply logit processor.
      * @param {bigint[][]} input_ids The input IDs.
      * @param {Tensor} logits The logits.
-     * @returns {Object} The processed logits.
+     * @returns {Tensor} The processed logits.
      */
     _call(input_ids, logits) {
         for (let i = 0; i < input_ids.length; ++i) {
@@ -596,7 +602,7 @@ export class ClassifierFreeGuidanceLogitsProcessor extends LogitsProcessor {
      * Apply logit processor.
      * @param {bigint[][]} input_ids The input IDs.
      * @param {Tensor} logits The logits.
-     * @returns {Object} The processed logits.
+     * @returns {Tensor} The processed logits.
      */
     _call(input_ids, logits) {
         if (logits.dims[0] !== 2 * input_ids.length) {
@@ -650,7 +656,7 @@ export class TemperatureLogitsWarper extends LogitsWarper {
      * Apply logit warper.
      * @param {bigint[][]} input_ids The input IDs.
      * @param {Tensor} logits The logits.
-     * @returns {Object} The processed logits.
+     * @returns {Tensor} The processed logits.
      */
     _call(input_ids, logits) {
         const batch_logits_data = /** @type {Float32Array} */(logits.data);
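The `TemperatureLogitsWarper` hunk is truncated here, but the operation it performs is simple enough to sketch standalone (again with a plain array in place of the Tensor data, and an illustrative function name): every logit is divided by the temperature, which flattens the resulting softmax distribution for temperatures above 1 and sharpens it for temperatures below 1.

```javascript
// Minimal sketch of temperature warping (illustrative name, plain array
// instead of the library's Tensor data).
function applyTemperature(logits, temperature) {
  // Dividing all logits by T rescales the softmax: T > 1 spreads probability
  // mass more evenly, T < 1 concentrates it on the highest logits.
  for (let i = 0; i < logits.length; ++i) {
    logits[i] /= temperature;
  }
  return logits;
}

const logits = [3.0, 1.0, -2.0];
applyTemperature(logits, 2.0);
console.log(logits); // [1.5, 0.5, -1]
```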