From ea2dab19edabd0bef6a1de57bb84d902b1976a2e Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Wed, 27 Nov 2024 11:12:30 -0800 Subject: [PATCH] Revert "fix janus model" This reverts commit ab2bd3e137e5e49112662571bd6e8962c32eb74f. --- src/models.js | 41 +++++++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/src/models.js b/src/models.js index 2ac75d2d5..7133e64d5 100644 --- a/src/models.js +++ b/src/models.js @@ -44,10 +44,10 @@ import { } from './configs.js'; import { - createInferenceSession, deviceToExecutionProviders, - isONNXProxy, + createInferenceSession, isONNXTensor, + isONNXProxy, } from './backends/onnx.js'; import { DATA_TYPES, @@ -75,48 +75,49 @@ import { } from './utils/constants.js'; import { - ClassifierFreeGuidanceLogitsProcessor, + LogitsProcessorList, ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, - LogitsProcessorList, - MinLengthLogitsProcessor, - MinNewTokensLengthLogitsProcessor, - NoBadWordsLogitsProcessor, + SuppressTokensAtBeginLogitsProcessor, + WhisperTimeStampLogitsProcessor, NoRepeatNGramLogitsProcessor, RepetitionPenaltyLogitsProcessor, - SuppressTokensAtBeginLogitsProcessor, + NoBadWordsLogitsProcessor, + MinLengthLogitsProcessor, + MinNewTokensLengthLogitsProcessor, + TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, - WhisperTimeStampLogitsProcessor, + ClassifierFreeGuidanceLogitsProcessor, } from './generation/logits_process.js'; import { GenerationConfig, } from './generation/configuration_utils.js'; -import { RawImage } from './utils/image.js'; import { cat, - full, - full_like, mean, + zeros, + zeros_like, ones, ones_like, + full, + full_like, stack, std_mean, Tensor, - zeros, - zeros_like, } from './utils/tensor.js'; +import { RawImage } from './utils/image.js'; -import { apis } from './env.js'; -import { LogitsSampler } from './generation/logits_sampler.js'; -import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js'; import { dynamic_time_warping, max, medianFilter } from './utils/maths.js'; +import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js'; +import { LogitsSampler } from './generation/logits_sampler.js'; +import { apis } from './env.js'; -import { whisper_language_to_code } from './models/whisper/common_whisper.js'; import { WhisperGenerationConfig } from './models/whisper/generation_whisper.js'; +import { whisper_language_to_code } from './models/whisper/common_whisper.js'; ////////////////////////////////////////////////// // Model types: used internally @@ -770,6 +771,10 @@ function multimodality_prepare_inputs_for_generation(self, input_ids, model_inpu } } + if (has_past_key_values || !model_inputs.pixel_values) { + model_inputs.pixel_values = full([0, 0, 3, 384, 384], 1.0); + } + if (has_past_key_values) { const num_img_tokens = 0; const num_text_tokens = 1;