Skip to content

Commit

Permalink
Fix addPastKeyValues for VisionEncoderDecoder models
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Sep 19, 2023
1 parent 841cdb8 commit b0cb176
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -1214,7 +1214,7 @@ export class PreTrainedModel extends Callable {
} else {
// TODO support batches (i.e., batch_size > 1)
// @ts-ignore
if (this.config.is_encoder_decoder) {
if (this.config.is_encoder_decoder && (this.add_encoder_pkv ?? true)) {
// @ts-ignore
let encoder_dims = [1, this.num_encoder_heads, 0, this.encoder_dim_kv];
// @ts-ignore
Expand Down Expand Up @@ -2596,7 +2596,8 @@ export class VisionEncoderDecoderModel extends PreTrainedModel {
// @ts-ignore
const decoder = new decoderModelClass(decoderConfig, decoder_merged_session, generation_config);

if ('num_decoder_layers' in decoder) {
this.add_encoder_pkv = 'num_decoder_layers' in decoder;
if (this.add_encoder_pkv) {
// Decoder is part of an encoder-decoder model
this.num_decoder_layers = decoder.num_decoder_layers;
this.num_decoder_heads = decoder.num_decoder_heads;
Expand Down

0 comments on commit b0cb176

Please sign in to comment.