From 4fa3b7d676f7e0cd1540ae33b3043a6862de2e1e Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sun, 24 Sep 2023 00:56:56 +0200 Subject: [PATCH] Only cut `decoder_input_ids` if past model output --- src/models.js | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/src/models.js b/src/models.js index d7fd5df5a..ac4a78a7c 100644 --- a/src/models.js +++ b/src/models.js @@ -409,10 +409,17 @@ function seq2seqStartBeams(self, inputTokenIds, generation_config, numOutputToke async function seq2seqRunBeam(self, beam) { const input_name = self.main_input_name; + let decoder_input_ids = beam.output_token_ids; + if (beam.prev_model_outputs) { + // After the first step, `prev_model_outputs` won't be null. + // So, we cut decoder_input_ids if past is used + decoder_input_ids = decoder_input_ids.slice(-1); + } + // 1. Prepare let model_inputs = { [input_name]: beam.inputs, - decoder_input_ids: toI64Tensor(beam.output_token_ids.slice(-1)), + decoder_input_ids: toI64Tensor(decoder_input_ids), encoder_outputs: beam.encoder_outputs, past_key_values: beam.prev_model_outputs?.past_key_values, } @@ -3294,15 +3301,6 @@ export class MarianMTModel extends MarianPreTrainedModel { this.num_encoder_heads = this.config.encoder_attention_heads; this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads; } - - - /** - * @param {any} model_inputs - * @returns {Promise} - */ - async forward(model_inputs) { - return await seq2seqForward(this, model_inputs); - } } ////////////////////////////////////////////////// @@ -3335,13 +3333,6 @@ export class M2M100ForConditionalGeneration extends M2M100PreTrainedModel { this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads; } - /** - * @param {any} model_inputs - * @returns {Promise} - */ - async forward(model_inputs) { - return await seq2seqForward(this, model_inputs); - } } //////////////////////////////////////////////////