diff --git a/mediapipe/tasks/web/genai/llm_inference/llm_inference.ts b/mediapipe/tasks/web/genai/llm_inference/llm_inference.ts
index 0b3295bffa..b2a11c1861 100644
--- a/mediapipe/tasks/web/genai/llm_inference/llm_inference.ts
+++ b/mediapipe/tasks/web/genai/llm_inference/llm_inference.ts
@@ -148,7 +148,7 @@ class Deferred {
  */
 function roundUpToNearestEven(n: number): number {
   if (n === 1) return 1;
-  return (n + (n % 2));
+  return n + (n % 2);
 }
 
 /**
@@ -851,6 +851,11 @@ export class LlmInference extends TaskRunner {
     // Don't trigger the user progress listener if there are WebGPU errors.
     if (this.userProgressListener && this.wgpuErrors.length === 0) {
       if (this.isMultiResponseGeneration) {
+        // TODO: Remove this when we no longer need to have an
+        // even number of responses in multi-output.
+        if (decodedText.length > this.options.getNumResponses()) {
+          decodedText.pop();
+        }
         (this.userProgressListener as MultiResponseProgressListener)(
           decodedText,
           /* done= */ false,
@@ -1047,8 +1052,9 @@ export class LlmInference extends TaskRunner {
     llmGpuOptions.setSequenceBatchSize(0);
     // TODO: Remove this when we no longer need to have an even
     // number of responses in multi-output.
-    llmGpuOptions.setNumOutputHeads(roundUpToNearestEven(
-        this.options.getNumResponses()));
+    llmGpuOptions.setNumOutputHeads(
+      roundUpToNearestEven(this.options.getNumResponses()),
+    );
     llmGpuOptions.setSamplerParams(this.options.getSamplerParams());
 
     const gpuModelInfo = new LlmGpuCalculatorOptions.GpuModelInfo();
@@ -1097,8 +1103,9 @@ export class LlmInference extends TaskRunner {
     const detokenizerOptions = new DetokenizerCalculatorOptions();
     // TODO: Remove this when we no longer need to have an even
     // number of responses in multi-output.
-    detokenizerOptions.setNumOutputHeads(roundUpToNearestEven(
-        this.options.getNumResponses()));
+    detokenizerOptions.setNumOutputHeads(
+      roundUpToNearestEven(this.options.getNumResponses()),
+    );
     // No need to set spm model, instead reuse TokenizerCalculator's side input.
     detokenizerOptions.addStopTokens('');
     detokenizerOptions.addStopTokens('<|endoftext|>');
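
For context, a minimal standalone sketch of the even-head workaround this patch wires through: the graph is configured with an even number of output heads, and the padding response is dropped before the progress listener fires. numResponses stands in for this.options.getNumResponses(), and trimToRequested is a hypothetical helper mirroring the inlined decodedText.pop() call; neither name exists in the patch itself.

// Rounds n up to the nearest even number; n === 1 is left alone because a
// single-response run does not use the multi-output path.
function roundUpToNearestEven(n: number): number {
  if (n === 1) return 1;
  return n + (n % 2);
}

// Drops the padding response so callers only ever see the number of
// responses they asked for (hypothetical helper; the patch inlines this).
function trimToRequested(decodedText: string[], numResponses: number): string[] {
  if (decodedText.length > numResponses) {
    decodedText.pop();
  }
  return decodedText;
}

// A request for 3 responses runs with 4 output heads; the 4th (padding)
// response is trimmed away before it reaches the user.
const numResponses = 3;
console.log(roundUpToNearestEven(numResponses)); // 4
console.log(trimToRequested(['a', 'b', 'c', 'd'], numResponses)); // ['a', 'b', 'c']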