Commit

Quick fix for occasionally returning the wrong number of multi-outputs when streaming
PiperOrigin-RevId: 676573033
MediaPipe Team authored and copybara-github committed Sep 19, 2024
1 parent 6e2240d commit 5889842
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions mediapipe/tasks/web/genai/llm_inference/llm_inference.ts
@@ -148,7 +148,7 @@ class Deferred<T> {
  */
 function roundUpToNearestEven(n: number): number {
   if (n === 1) return 1;
-  return (n + (n % 2));
+  return n + (n % 2);
 }

 /**
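
For reference, roundUpToNearestEven maps a requested response count onto the even output-head count the pipeline currently requires. The calls below are illustrative only, not part of the commit:

    roundUpToNearestEven(1);  // 1: a single response needs no padding
    roundUpToNearestEven(3);  // 4: odd counts get one padded output head
    roundUpToNearestEven(4);  // 4: even counts pass through unchanged
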
@@ -851,6 +851,11 @@ export class LlmInference extends TaskRunner {
     // Don't trigger the user progress listener if there are WebGPU errors.
     if (this.userProgressListener && this.wgpuErrors.length === 0) {
       if (this.isMultiResponseGeneration) {
+        // TODO: Remove this when we no longer need to have an
+        // even number of responses in multi-output.
+        if (decodedText.length > this.options.getNumResponses()) {
+          decodedText.pop();
+        }
         (this.userProgressListener as MultiResponseProgressListener)(
           decodedText,
           /* done= */ false,
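
In effect, when an odd number of responses is requested, the graph is configured with one padded output head (via the roundUpToNearestEven calls further down), so each streaming callback can carry one surplus partial response; the new check drops it before the user's progress listener fires. A minimal sketch of that interplay, with illustrative values standing in for the task's real runtime state:

    // Sketch only: `requested` and `decodedText` stand in for
    // this.options.getNumResponses() and the decoded partial results.
    const requested = 3;
    const outputHeads = roundUpToNearestEven(requested);  // 4 heads in the graph
    const decodedText = ['a', 'b', 'c', 'd'];             // one partial response per head
    if (decodedText.length > requested) {
      decodedText.pop();  // drop the padded response, leaving the 3 that were asked for
    }
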
@@ -1047,8 +1052,9 @@
     llmGpuOptions.setSequenceBatchSize(0);
     // TODO: Remove this when we no longer need to have an even
     // number of responses in multi-output.
-    llmGpuOptions.setNumOutputHeads(roundUpToNearestEven(
-        this.options.getNumResponses()));
+    llmGpuOptions.setNumOutputHeads(
+      roundUpToNearestEven(this.options.getNumResponses()),
+    );
     llmGpuOptions.setSamplerParams(this.options.getSamplerParams());

     const gpuModelInfo = new LlmGpuCalculatorOptions.GpuModelInfo();
@@ -1097,8 +1103,9 @@
     const detokenizerOptions = new DetokenizerCalculatorOptions();
     // TODO: Remove this when we no longer need to have an even
     // number of responses in multi-output.
-    detokenizerOptions.setNumOutputHeads(roundUpToNearestEven(
-        this.options.getNumResponses()));
+    detokenizerOptions.setNumOutputHeads(
+      roundUpToNearestEven(this.options.getNumResponses()),
+    );
     // No need to set spm model, instead reuse TokenizerCalculator's side input.
     detokenizerOptions.addStopTokens('<eos>');
     detokenizerOptions.addStopTokens('<|endoftext|>');
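
Both LlmGpuCalculatorOptions and DetokenizerCalculatorOptions are rounded with the same helper, so the GPU calculator and the detokenizer agree on the (possibly padded) number of output heads, while the pop() added in the streaming callback above keeps the padded response from ever reaching the user's progress listener.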