diff --git a/src/models.js b/src/models.js
index 641cf1d17..d7fd5df5a 100644
--- a/src/models.js
+++ b/src/models.js
@@ -362,18 +362,14 @@ function seq2seqStartBeams(self, inputTokenIds, generation_config, numOutputToke
         ?? generation_config.eos_token_id;
 
     // Support input as tensor or list
+    // TODO support batched decoder_input_ids
     if (decoder_input_ids instanceof Tensor) {
-        if (decoder_input_ids.dims.length === 1) {
-            decoder_input_ids.dims = [1, decoder_input_ids.dims[0]]
-        }
-        decoder_input_ids = decoder_input_ids.tolist();
+        decoder_input_ids = decoder_input_ids.tolist().flat();
     } else if (!Array.isArray(decoder_input_ids)) {
         decoder_input_ids = [decoder_input_ids];
     }
 
-    for (let i = 0; i < inputTokenIds.dims[0]; ++i) {
-        let tokens = inputTokenIds[i];
-        let batch_decoder_input_ids = decoder_input_ids[i].map(Number);
+    for (let tokens of inputTokenIds) {
         // TODO: Improve
         // Currently, just add back batch dimension.
         // In future, allow for true parallel execution
@@ -385,7 +381,7 @@ function seq2seqStartBeams(self, inputTokenIds, generation_config, numOutputToke
             encoder_outputs: null,
             prev_model_outputs: null,
 
-            output_token_ids: batch_decoder_input_ids,
+            output_token_ids: decoder_input_ids,
             done: false,
             score: 0,
             id: beamId++ // assign unique id to beams
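In effect, the patch normalizes `decoder_input_ids` once into a flat array of token ids and reuses that same array as the starting `output_token_ids` of every beam, instead of indexing a per-example row of a batched prompt (true batched decoder prompts are left as a TODO). A minimal, self-contained sketch of that normalization behavior (the `Tensor` stub and the helper name `normalizeDecoderInputIds` are illustrative assumptions, not part of the library):

```js
// Sketch of the normalization the patch performs. Assumes a Tensor class
// whose tolist() returns a nested array, as in the diff above.
class Tensor {
    constructor(data) { this.data = data; }
    tolist() { return this.data; }
}

// Hypothetical helper mirroring the patched branch logic.
function normalizeDecoderInputIds(decoder_input_ids) {
    if (decoder_input_ids instanceof Tensor) {
        // e.g. a [1, 2] tensor holding [[0, 50258]] becomes [0, 50258]
        decoder_input_ids = decoder_input_ids.tolist().flat();
    } else if (!Array.isArray(decoder_input_ids)) {
        // a single start token id, e.g. 50256, becomes [50256]
        decoder_input_ids = [decoder_input_ids];
    }
    return decoder_input_ids;
}

// Every beam now starts from the same flat prefix:
console.log(normalizeDecoderInputIds(new Tensor([[0, 50258]]))); // [0, 50258]
console.log(normalizeDecoderInputIds(50256));                    // [50256]
console.log(normalizeDecoderInputIds([50256]));                  // [50256]
```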