From 154858c24e34d5d6acc4b3f37db043d0fc148d1c Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sun, 10 Sep 2023 02:54:39 +0200 Subject: [PATCH 1/9] Add support for `Blenderbot` models Closes #37 References #29 --- src/models.js | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/src/models.js b/src/models.js index b7f0c8214..5c1e2b20f 100644 --- a/src/models.js +++ b/src/models.js @@ -2033,6 +2033,45 @@ export class MBartForSequenceClassification extends MBartPreTrainedModel { ////////////////////////////////////////////////// + +////////////////////////////////////////////////// +// Blenderbot models +export class BlenderbotPreTrainedModel extends PreTrainedModel { }; + +/** + * The bare Blenderbot Model outputting raw hidden-states without any specific head on top. + */ +export class BlenderbotModel extends BlenderbotPreTrainedModel { } + +/** + * The Blenderbot Model with a language modeling head. Can be used for summarization. + */ +export class BlenderbotForConditionalGeneration extends BlenderbotPreTrainedModel { + + /** + * Creates a new instance of the `BlenderbotForConditionalGeneration` class. + * @param {any} config The model configuration. + * @param {any} session The ONNX session containing the encoder weights. + * @param {any} decoder_merged_session The ONNX session containing the merged decoder weights. + * @param {GenerationConfig} generation_config The generation configuration. + */ + constructor(config, session, decoder_merged_session, generation_config) { + super(config, session); + this.decoder_merged_session = decoder_merged_session; + this.generation_config = generation_config; + + this.num_decoder_layers = this.config.decoder_layers; + this.num_decoder_heads = this.config.decoder_attention_heads; + this.decoder_dim_kv = this.config.d_model / this.num_decoder_heads; + + this.num_encoder_layers = this.config.encoder_layers; + this.num_encoder_heads = this.config.encoder_attention_heads; + this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads; + } +} +////////////////////////////////////////////////// + + ////////////////////////////////////////////////// // Roberta models export class RobertaPreTrainedModel extends PreTrainedModel { } @@ -3417,6 +3456,7 @@ const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([ ['marian', MarianModel], ['whisper', WhisperModel], ['m2m_100', M2M100Model], + ['blenderbot', BlenderbotModel], ]); @@ -3470,6 +3510,7 @@ const MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES = new Map([ ['whisper', WhisperForConditionalGeneration], ['marian', MarianMTModel], ['m2m_100', M2M100ForConditionalGeneration], + ['blenderbot', BlenderbotForConditionalGeneration], ]); const MODEL_WITH_LM_HEAD_MAPPING_NAMES = new Map([ From fa40795233eb1322ca0d711d2a4917f63562efdc Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sun, 10 Sep 2023 02:54:51 +0200 Subject: [PATCH 2/9] Add support for `BlenderbotTokenizer` --- src/tokenizers.js | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/tokenizers.js b/src/tokenizers.js index d0f47923e..10dee268a 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -3705,6 +3705,9 @@ export class MarianTokenizer extends PreTrainedTokenizer { export class Wav2Vec2CTCTokenizer extends PreTrainedTokenizer { } +export class BlenderbotTokenizer extends PreTrainedTokenizer { } +// export class BlenderbotSmallTokenizer extends PreTrainedTokenizer { } + /** * Helper class which is used to instantiate pretrained tokenizers with the `from_pretrained` function. * The chosen tokenizer class is determined by the type specified in the tokenizer config. @@ -3744,6 +3747,7 @@ export class AutoTokenizer { FalconTokenizer, GPTNeoXTokenizer, Wav2Vec2CTCTokenizer, + BlenderbotTokenizer, // Base case: PreTrainedTokenizer, From 0de885a6ac2d0f219dde5ca62fe932806122a60b Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Mon, 11 Sep 2023 01:25:48 +0200 Subject: [PATCH 3/9] Add blenderbot to supported models --- scripts/supported_models.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/scripts/supported_models.py b/scripts/supported_models.py index 0b4312d34..c998ef954 100644 --- a/scripts/supported_models.py +++ b/scripts/supported_models.py @@ -97,16 +97,16 @@ 'bert-base-chinese', 'emilyalsentzer/Bio_ClinicalBERT', ], - # 'blenderbot': [ - # # Text2text generation (TODO add conversational) - # 'facebook/blenderbot-400M-distill', - # 'facebook/blenderbot-1B-distill', - # ], - # 'blenderbot-small': [ - # # Text2text generation (TODO add conversational) - # 'facebook/blenderbot-90M', # DEPRECATED - # 'facebook/blenderbot_small-90M', - # ], + 'blenderbot': [ + # Text2text generation (TODO add conversational) + 'facebook/blenderbot-400M-distill', + # 'facebook/blenderbot-1B-distill', + ], + 'blenderbot-small': [ + # Text2text generation (TODO add conversational) + # 'facebook/blenderbot-90M', # DEPRECATED + 'facebook/blenderbot_small-90M', + ], 'bloom': [ # Text generation 'bigscience/bloom-560m', From 40df9934bf5e6216dff58ece5caa061c2f9514bc Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Mon, 11 Sep 2023 04:14:52 +0200 Subject: [PATCH 4/9] Add support for `BlenderbotSmallTokenizer` --- src/tokenizers.js | 45 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/src/tokenizers.js b/src/tokenizers.js index 10dee268a..9a0d295eb 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -516,6 +516,7 @@ class BPE extends TokenizerModel { * @param {Object} config.vocab A mapping of tokens to ids. * @param {string} config.unk_token The unknown token used for out of vocabulary words. * @param {string} config.end_of_word_suffix The suffix to place at the end of each word. + * @param {string} [config.continuing_subword_suffix] The suffix to insert between words. * @param {Array} config.merges An array of BPE merges as strings. */ constructor(config) { @@ -539,6 +540,9 @@ class BPE extends TokenizerModel { this.end_of_word_suffix = config.end_of_word_suffix; + // NOTE: `continuing_subword_suffix` is custom (to support `BlenderbotSmallTokenizer`) + this.continuing_subword_suffix = config.continuing_subword_suffix ?? null; + this.byte_fallback = this.config.byte_fallback ?? false; if (this.byte_fallback) { @@ -665,6 +669,14 @@ class BPE extends TokenizerModel { result = word; } + // Possibly append suffix + if (this.continuing_subword_suffix) { + // Do not append suffix to the last token + for (let i = 0; i < result.length - 1; ++i) { + result[i] += this.continuing_subword_suffix; + } + } + // Save the result to the cache this.cache.set(token, result); @@ -1116,6 +1128,8 @@ class PreTokenizer extends Callable { return new PunctuationPreTokenizer(config); case 'Digits': return new DigitsPreTokenizer(config); + case 'Replace': + return new ReplacePreTokenizer(config); default: throw new Error(`Unknown PreTokenizer type: ${config.type}`); } @@ -2079,6 +2093,34 @@ class WhitespaceSplit extends PreTokenizer { } } +// NOTE: `ReplacePreTokenizer` is custom (to support `BlenderbotSmallTokenizer`) +class ReplacePreTokenizer extends PreTokenizer { + /** + * @param {Object} config The configuration options for the pre-tokenizer. + * @param {Object} config.pattern The pattern used to split the text. Can be a string or a regex object. + * @param {string} config.content What to replace the pattern with. + */ + constructor(config) { + super(); + this.config = config; + this.pattern = createPattern(this.config.pattern); + this.content = this.config.content; + } + + /** + * Pre-tokenizes the input text by replacing certain characters. + * @param {string} text The text to be pre-tokenized. + * @returns {string[]} An array of tokens produced by replacing certain characters. + */ + pre_tokenize_text(text) { + if (this.pattern === null) { + return [text]; + } + return [text.replaceAll(this.pattern, this.config.content)]; + } +} + + export class PreTrainedTokenizer extends Callable { /** * Create a new PreTrainedTokenizer instance. @@ -3706,7 +3748,7 @@ export class MarianTokenizer extends PreTrainedTokenizer { export class Wav2Vec2CTCTokenizer extends PreTrainedTokenizer { } export class BlenderbotTokenizer extends PreTrainedTokenizer { } -// export class BlenderbotSmallTokenizer extends PreTrainedTokenizer { } +export class BlenderbotSmallTokenizer extends PreTrainedTokenizer { } /** * Helper class which is used to instantiate pretrained tokenizers with the `from_pretrained` function. @@ -3748,6 +3790,7 @@ export class AutoTokenizer { GPTNeoXTokenizer, Wav2Vec2CTCTokenizer, BlenderbotTokenizer, + BlenderbotSmallTokenizer, // Base case: PreTrainedTokenizer, From 1f7ea19a2d5edc7a211bc54b8c13763762e40777 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Mon, 11 Sep 2023 04:15:05 +0200 Subject: [PATCH 5/9] Add custom tests for blenderbot-small --- tests/generate_tests.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/generate_tests.py b/tests/generate_tests.py index 37fd5e43b..2e816459a 100644 --- a/tests/generate_tests.py +++ b/tests/generate_tests.py @@ -63,6 +63,13 @@ "weird \uFF5E edge \uFF5E case", ], "custom": { + "facebook/blenderbot_small-90M": [ + # Test special tokens + "__start__hello world__end__", + # The original (python) tokenizer simply joins by spaces (regardless of special tokens or not) + "__start__ hey __end__" # --> ... --> "__start__ hey __end__" + "__start__hey __end__" # --> ... --> "__start__ hey __end__" + ], "tiiuae/falcon-7b": [ "12 and 123 and 1234", # Special case for splitting on 3 numbers ], From cce3a1f57f273f4b926e7aad2c53006c6392f508 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Mon, 11 Sep 2023 04:19:02 +0200 Subject: [PATCH 6/9] Add support for `BlenderbotSmall` models --- src/models.js | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/src/models.js b/src/models.js index 5c1e2b20f..0db929f10 100644 --- a/src/models.js +++ b/src/models.js @@ -2072,6 +2072,44 @@ export class BlenderbotForConditionalGeneration extends BlenderbotPreTrainedMode ////////////////////////////////////////////////// +////////////////////////////////////////////////// +// Blenderbot models +export class BlenderbotSmallPreTrainedModel extends PreTrainedModel { }; + +/** + * The bare BlenderbotSmall Model outputting raw hidden-states without any specific head on top. + */ +export class BlenderbotSmallModel extends BlenderbotSmallPreTrainedModel { } + +/** + * The BlenderbotSmall Model with a language modeling head. Can be used for summarization. + */ +export class BlenderbotSmallForConditionalGeneration extends BlenderbotSmallPreTrainedModel { + + /** + * Creates a new instance of the `BlenderbotForConditionalGeneration` class. + * @param {any} config The model configuration. + * @param {any} session The ONNX session containing the encoder weights. + * @param {any} decoder_merged_session The ONNX session containing the merged decoder weights. + * @param {GenerationConfig} generation_config The generation configuration. + */ + constructor(config, session, decoder_merged_session, generation_config) { + super(config, session); + this.decoder_merged_session = decoder_merged_session; + this.generation_config = generation_config; + + this.num_decoder_layers = this.config.decoder_layers; + this.num_decoder_heads = this.config.decoder_attention_heads; + this.decoder_dim_kv = this.config.d_model / this.num_decoder_heads; + + this.num_encoder_layers = this.config.encoder_layers; + this.num_encoder_heads = this.config.encoder_attention_heads; + this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads; + } +} +////////////////////////////////////////////////// + + ////////////////////////////////////////////////// // Roberta models export class RobertaPreTrainedModel extends PreTrainedModel { } @@ -3457,6 +3495,7 @@ const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([ ['whisper', WhisperModel], ['m2m_100', M2M100Model], ['blenderbot', BlenderbotModel], + ['blenderbot-small', BlenderbotSmallModel], ]); @@ -3511,6 +3550,7 @@ const MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES = new Map([ ['marian', MarianMTModel], ['m2m_100', M2M100ForConditionalGeneration], ['blenderbot', BlenderbotForConditionalGeneration], + ['blenderbot-small', BlenderbotSmallForConditionalGeneration], ]); const MODEL_WITH_LM_HEAD_MAPPING_NAMES = new Map([ From 0dbe7b28edbf66a1f06e102f494caa9f72fbd675 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Mon, 11 Sep 2023 04:40:23 +0200 Subject: [PATCH 7/9] Update list of supported models --- README.md | 2 ++ docs/snippets/6_supported-models.snippet | 2 ++ 2 files changed, 4 insertions(+) diff --git a/README.md b/README.md index 815ed4c5e..785af5c11 100644 --- a/README.md +++ b/README.md @@ -257,6 +257,8 @@ You can refine your search by selecting the task you're interested in (e.g., [te 1. **[BART](https://huggingface.co/docs/transformers/model_doc/bart)** (from Facebook) released with the paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer. 1. **[BEiT](https://huggingface.co/docs/transformers/model_doc/beit)** (from Microsoft) released with the paper [BEiT: BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254) by Hangbo Bao, Li Dong, Furu Wei. 1. **[BERT](https://huggingface.co/docs/transformers/model_doc/bert)** (from Google) released with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. +1. **[Blenderbot](https://huggingface.co/docs/transformers/model_doc/blenderbot)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston. +1. **[BlenderbotSmall](https://huggingface.co/docs/transformers/model_doc/blenderbot-small)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston. 1. **[BLOOM](https://huggingface.co/docs/transformers/model_doc/bloom)** (from BigScience workshop) released by the [BigScience Workshop](https://bigscience.huggingface.co/). 1. **[CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert)** (from Inria/Facebook/Sorbonne) released with the paper [CamemBERT: a Tasty French Language Model](https://arxiv.org/abs/1911.03894) by Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot. 1. **[CLIP](https://huggingface.co/docs/transformers/model_doc/clip)** (from OpenAI) released with the paper [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) by Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever. diff --git a/docs/snippets/6_supported-models.snippet b/docs/snippets/6_supported-models.snippet index 72582bc83..03ee35b4d 100644 --- a/docs/snippets/6_supported-models.snippet +++ b/docs/snippets/6_supported-models.snippet @@ -5,6 +5,8 @@ 1. **[BART](https://huggingface.co/docs/transformers/model_doc/bart)** (from Facebook) released with the paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer. 1. **[BEiT](https://huggingface.co/docs/transformers/model_doc/beit)** (from Microsoft) released with the paper [BEiT: BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254) by Hangbo Bao, Li Dong, Furu Wei. 1. **[BERT](https://huggingface.co/docs/transformers/model_doc/bert)** (from Google) released with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. +1. **[Blenderbot](https://huggingface.co/docs/transformers/model_doc/blenderbot)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston. +1. **[BlenderbotSmall](https://huggingface.co/docs/transformers/model_doc/blenderbot-small)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston. 1. **[BLOOM](https://huggingface.co/docs/transformers/model_doc/bloom)** (from BigScience workshop) released by the [BigScience Workshop](https://bigscience.huggingface.co/). 1. **[CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert)** (from Inria/Facebook/Sorbonne) released with the paper [CamemBERT: a Tasty French Language Model](https://arxiv.org/abs/1911.03894) by Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot. 1. **[CLIP](https://huggingface.co/docs/transformers/model_doc/clip)** (from OpenAI) released with the paper [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) by Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever. From ed648094c98a18f3c8c97508bf7f44e3eefceb10 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 12 Sep 2023 01:05:31 +0200 Subject: [PATCH 8/9] Improve `addPastKeyValues` function --- src/models.js | 70 +++++++++++++++++++++++---------------------------- 1 file changed, 31 insertions(+), 39 deletions(-) diff --git a/src/models.js b/src/models.js index 0db929f10..acc7742d9 100644 --- a/src/models.js +++ b/src/models.js @@ -308,7 +308,6 @@ function boolTensor(value) { * @private */ async function seq2seqForward(self, model_inputs) { - const add_decoder_pkv = self.add_decoder_pkv ?? true; let { encoder_outputs, past_key_values } = model_inputs; @@ -325,7 +324,7 @@ async function seq2seqForward(self, model_inputs) { if (self.decoder_merged_session.inputNames.includes('encoder_attention_mask')) { decoderFeeds.encoder_attention_mask = model_inputs.attention_mask } - self.addPastKeyValues(decoderFeeds, past_key_values, add_decoder_pkv); + self.addPastKeyValues(decoderFeeds, past_key_values); const decoderResults = await sessionRun(self.decoder_merged_session, decoderFeeds); let logits = decoderResults.logits; @@ -1182,57 +1181,50 @@ export class PreTrainedModel extends Callable { * * @param {Object} decoderFeeds The decoder feeds object to add past key values to. * @param {Object} pastKeyValues An object containing past key values. - * @param {boolean} [hasDecoder=false] Whether the model has a decoder. */ - addPastKeyValues(decoderFeeds, pastKeyValues, hasDecoder = false) { + addPastKeyValues(decoderFeeds, pastKeyValues) { if (pastKeyValues) { Object.assign(decoderFeeds, pastKeyValues) } else { // TODO support batches (i.e., batch_size > 1) - if (hasDecoder) { + if (this.config.is_encoder_decoder) { // @ts-ignore let encoder_dims = [1, this.num_encoder_heads, 0, this.encoder_dim_kv]; - // @ts-ignore - for (let i = 0; i < this.num_encoder_layers; ++i) { - decoderFeeds[`past_key_values.${i}.encoder.key`] = new Tensor('float32', [], encoder_dims) - decoderFeeds[`past_key_values.${i}.encoder.value`] = new Tensor('float32', [], encoder_dims) - } - // @ts-ignore let decoder_dims = [1, this.num_decoder_heads, 0, this.decoder_dim_kv]; // @ts-ignore for (let i = 0; i < this.num_decoder_layers; ++i) { + decoderFeeds[`past_key_values.${i}.encoder.key`] = new Tensor('float32', [], encoder_dims) + decoderFeeds[`past_key_values.${i}.encoder.value`] = new Tensor('float32', [], encoder_dims) decoderFeeds[`past_key_values.${i}.decoder.key`] = new Tensor('float32', [], decoder_dims) decoderFeeds[`past_key_values.${i}.decoder.value`] = new Tensor('float32', [], decoder_dims) } + } else if (this.config.multi_query) { // e.g., for `gpt_bigcode` + // @ts-ignore + let dims = [1, 0, 2 * this.dim_kv] + // @ts-ignore + for (let i = 0; i < this.num_layers; ++i) { + decoderFeeds[`past_key_values.${i}.key_value`] = new Tensor('float32', [], dims) + } + } else if (this.config.model_type === 'bloom') { + // NOTE: Custom implementation for Bloom - } else { - if (this.config.multi_query) { - // @ts-ignore - let dims = [1, 0, 2 * this.dim_kv] - // @ts-ignore - for (let i = 0; i < this.num_layers; ++i) { - decoderFeeds[`past_key_values.${i}.key_value`] = new Tensor('float32', [], dims) - } - } else if (this.config.model_type === 'bloom') { - // Custom implementation for Bloom - // @ts-ignore - let keyDims = [1 * this.num_heads, this.dim_kv, 0] // [batch_size x num_heads,64,past_sequence_length] - // @ts-ignore - let valueDims = [1 * this.num_heads, 0, this.dim_kv] // [batch_size x num_heads,past_sequence_length,64] - // @ts-ignore - for (let i = 0; i < this.num_layers; ++i) { - decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], keyDims) - decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], valueDims) - } - } else { - // @ts-ignore - let dims = [1, this.num_heads, 0, this.dim_kv] - // @ts-ignore - for (let i = 0; i < this.num_layers; ++i) { - decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], dims) - decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], dims) - } + // @ts-ignore + let keyDims = [1 * this.num_heads, this.dim_kv, 0] // [batch_size x num_heads,64,past_sequence_length] + // @ts-ignore + let valueDims = [1 * this.num_heads, 0, this.dim_kv] // [batch_size x num_heads,past_sequence_length,64] + // @ts-ignore + for (let i = 0; i < this.num_layers; ++i) { + decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], keyDims) + decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], valueDims) + } + } else { // Decoder-only + // @ts-ignore + let dims = [1, this.num_heads, 0, this.dim_kv] + // @ts-ignore + for (let i = 0; i < this.num_layers; ++i) { + decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], dims) + decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], dims) } } } @@ -2546,7 +2538,7 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { */ export class VisionEncoderDecoderModel extends PreTrainedModel { main_input_name = 'pixel_values'; - add_decoder_pkv = false; + // add_decoder_pkv = false; /** * Creates a new instance of the `VisionEncoderDecoderModel` class. From dfd18e2f65be5e31e49420db97b636df12230cc3 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 12 Sep 2023 01:17:41 +0200 Subject: [PATCH 9/9] Allow skipping of adding encoder past key values --- src/models.js | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/models.js b/src/models.js index acc7742d9..fa8e89d5f 100644 --- a/src/models.js +++ b/src/models.js @@ -1187,7 +1187,8 @@ export class PreTrainedModel extends Callable { Object.assign(decoderFeeds, pastKeyValues) } else { // TODO support batches (i.e., batch_size > 1) - if (this.config.is_encoder_decoder) { + // @ts-ignore + if (this.config.is_encoder_decoder && (this.add_encoder_pkv ?? true)) { // @ts-ignore let encoder_dims = [1, this.num_encoder_heads, 0, this.encoder_dim_kv]; // @ts-ignore @@ -2538,7 +2539,7 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { */ export class VisionEncoderDecoderModel extends PreTrainedModel { main_input_name = 'pixel_values'; - // add_decoder_pkv = false; + add_encoder_pkv = false; /** * Creates a new instance of the `VisionEncoderDecoderModel` class.