diff --git a/mediapipe/tasks/web/genai/index.ts b/mediapipe/tasks/web/genai/index.ts
index 847f610968..1f229bc790 100644
--- a/mediapipe/tasks/web/genai/index.ts
+++ b/mediapipe/tasks/web/genai/index.ts
@@ -15,11 +15,15 @@
  */
 
 import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver';
-import {LlmInference as LlmInferenceImpl} from '../../../tasks/web/genai/llm_inference/llm_inference';
+import {
+  LlmInference as LlmInferenceImpl,
+  LoraModel as LoraModelImpl,
+} from '../../../tasks/web/genai/llm_inference/llm_inference';
 
 // Declare the variables locally so that Rollup in OSS includes them explicitly
 // as exports.
 const FilesetResolver = FilesetResolverImpl;
 const LlmInference = LlmInferenceImpl;
+const LoraModel = LoraModelImpl;
 
-export {FilesetResolver, LlmInference};
+export {FilesetResolver, LlmInference, LoraModel};
diff --git a/mediapipe/tasks/web/genai/llm_inference/llm_inference.ts b/mediapipe/tasks/web/genai/llm_inference/llm_inference.ts
index b1808569bd..674e17a938 100644
--- a/mediapipe/tasks/web/genai/llm_inference/llm_inference.ts
+++ b/mediapipe/tasks/web/genai/llm_inference/llm_inference.ts
@@ -74,6 +74,10 @@ const OUTPUT_END_STREAM = 'text_end';
 const TOKEN_COST_INPUT_STREAM = 'token_cost_in';
 const TOKEN_COST_OUTPUT_STREAM = 'token_cost_out';
 
+const LORA_MODEL_ID_TO_APPLY_INPUT_STREAM = 'lora_model_id_to_apply_in';
+const LORA_MODEL_REF_INPUT_STREAM = 'lora_model_ref_in';
+const LORA_MODEL_ID_TO_LOAD_INPUT_STREAM = 'lora_model_id_to_load_in';
+
 const DEFAULT_MAX_TOKENS = 512;
 const DEFAULT_TOP_K = 1;
 const DEFAULT_TOP_P = 1.0;
@@ -91,6 +95,18 @@ const MAX_BUFFER_SIZE_FOR_LLM = 524550144;
 // Amount of the max WebGPU buffer binding size required for LLM models.
 const MAX_STORAGE_BUFFER_BINDING_SIZE_FOR_LLM = 524550144;
 
+/**
+ * The LoRA model to be used for `generateResponse()` of a LLM Inference task.
+ */
+export class LoraModel {
+  private static nextLoraModelId = 0;
+  readonly loraModelId: number;
+  constructor(readonly owner: LlmInference) {
+    this.loraModelId = LoraModel.nextLoraModelId;
+    LoraModel.nextLoraModelId++;
+  }
+}
+
 /**
  * Performs LLM Inference on text.
  */
@@ -300,6 +316,9 @@ export class LlmInference extends TaskRunner {
     if (options.randomSeed) {
       this.samplerParams.setSeed(options.randomSeed);
     }
+    if ('loraRanks' in options) {
+      this.options.setLoraRanksList(options.loraRanks ?? []);
+    }
 
     let onFinishedLoadingData!: () => void;
     const finishedLoadingDataPromise = new Promise<void>((resolve, reject) => {
@@ -398,24 +417,69 @@ export class LlmInference extends TaskRunner {
     text: string,
     progressListener: ProgressListener,
   ): Promise<string>;
+  /**
+   * Performs LLM Inference on the provided text and waits
+   * asynchronously for the response. Only one call to `generateResponse()` can
+   * run at a time.
+   *
+   * @export
+   * @param text The text to process.
+   * @param loraModel The LoRA model to apply on the text generation.
+   * @return The generated text result.
+   */
+  generateResponse(text: string, loraModel: LoraModel): Promise<string>;
+  /**
+   * Performs LLM Inference on the provided text and waits
+   * asynchronously for the response. Only one call to `generateResponse()` can
+   * run at a time.
+   *
+   * @export
+   * @param text The text to process.
+   * @param loraModel The LoRA model to apply on the text generation.
+   * @param progressListener A listener that will be triggered when the task has
+   *     new partial response generated.
+   * @return The generated text result.
+   */
+  generateResponse(
+    text: string,
+    loraModel: LoraModel,
+    progressListener: ProgressListener,
+  ): Promise<string>;
   /** @export */
   generateResponse(
     text: string,
+    loraModelOrProgressListener?: ProgressListener | LoraModel,
     progressListener?: ProgressListener,
   ): Promise<string> {
     if (this.isProcessing) {
       throw new Error('Previous invocation or loading is still ongoing.');
     }
-    if (progressListener) {
-      this.userProgressListener = progressListener;
-    }
+    this.userProgressListener =
+      typeof loraModelOrProgressListener === 'function'
+        ? loraModelOrProgressListener
+        : progressListener;
     this.generationResult.length = 0;
     this.isProcessing = true;
-    this.graphRunner.addStringToStream(
-      text,
-      INPUT_STREAM,
-      this.getSynctheticTimestamp(),
-    );
+    const timeStamp = this.getSynctheticTimestamp();
+    this.graphRunner.addStringToStream(text, INPUT_STREAM, timeStamp);
+    if (loraModelOrProgressListener instanceof LoraModel) {
+      if (loraModelOrProgressListener.owner !== this) {
+        this.isProcessing = false;
+        throw new Error(
+          'The LoRA model was not loaded by this LLM Inference task.',
+        );
+      }
+      this.graphRunner.addUintToStream(
+        loraModelOrProgressListener.loraModelId,
+        LORA_MODEL_ID_TO_APPLY_INPUT_STREAM,
+        timeStamp,
+      );
+    } else {
+      this.graphRunner.addEmptyPacketToStream(
+        LORA_MODEL_ID_TO_APPLY_INPUT_STREAM,
+        timeStamp,
+      );
+    }
     this.finishProcessing();
     return new Promise<string>((resolve, reject) => {
       this.resolveGeneration = resolve;
@@ -448,6 +512,52 @@ export class LlmInference extends TaskRunner {
     return this.latestTokenCostQueryResult;
   }
 
+  /**
+   * Load a LoRA model to the LLM Inference Task and the LoRA model can be used
+   * by `generateResponse()`. The returned LoRA model can be applied only to the
+   * current LLM Inference task.
+   *
+   * @export
+   * @param modelAsset The URL to the model or the ArrayBuffer of the model
+   *     content.
+   * @return A loaded LoRA model.
+   */
+  async loadLoraModel(
+    modelAsset: string | Uint8Array,
+  ): Promise<LoraModel> {
+    if (this.isProcessing) {
+      throw new Error('Cannot load LoRA model while loading or processing.');
+    }
+    this.isProcessing = true;
+    const wasmFileReference =
+      modelAsset instanceof Uint8Array
+        ? WasmFileReference.loadFromArray(
+            this.graphRunner.wasmModule,
+            modelAsset,
+          )
+        : await WasmFileReference.loadFromUrl(
+            this.graphRunner.wasmModule,
+            modelAsset,
+          );
+    const loraModel = new LoraModel(this);
+    (
+      this.graphRunner as unknown as LlmGraphRunner
+    ).addWasmFileReferenceToStream(
+      wasmFileReference,
+      LORA_MODEL_REF_INPUT_STREAM,
+      this.getSynctheticTimestamp(),
+    );
+    this.graphRunner.addUintToStream(
+      loraModel.loraModelId,
+      LORA_MODEL_ID_TO_LOAD_INPUT_STREAM,
+      this.getSynctheticTimestamp(),
+    );
+    this.finishProcessing();
+    wasmFileReference.free();
+    this.isProcessing = false;
+    return loraModel;
+  }
+
   /**
    * Decodes the response from the LLM engine and returns a human-readable
    * string.
@@ -568,6 +678,9 @@ export class LlmInference extends TaskRunner {
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(INPUT_STREAM);
     graphConfig.addInputStream(TOKEN_COST_INPUT_STREAM);
+    graphConfig.addInputStream(LORA_MODEL_ID_TO_APPLY_INPUT_STREAM);
+    graphConfig.addInputStream(LORA_MODEL_REF_INPUT_STREAM);
+    graphConfig.addInputStream(LORA_MODEL_ID_TO_LOAD_INPUT_STREAM);
     graphConfig.addInputSidePacket('streaming_reader');
     graphConfig.addOutputStream(OUTPUT_STREAM);
     graphConfig.addOutputStream(OUTPUT_END_STREAM);
@@ -577,6 +690,9 @@ export class LlmInference extends TaskRunner {
     const tokenizerInputBuildNode = new CalculatorGraphConfig.Node();
     tokenizerInputBuildNode.setCalculator('TokenizerInputBuildCalculator');
     tokenizerInputBuildNode.addInputStream('PROMPT:' + INPUT_STREAM);
+    tokenizerInputBuildNode.addInputStream(
+      'LORA_ID:' + LORA_MODEL_ID_TO_APPLY_INPUT_STREAM,
+    );
     tokenizerInputBuildNode.addOutputStream('prompt');
     graphConfig.addNode(tokenizerInputBuildNode);
 
@@ -586,6 +702,13 @@ export class LlmInference extends TaskRunner {
     modelDataNode.addOutputSidePacket('MODEL_DATA:' + '__side_packet_1');
     modelDataNode.addOutputSidePacket('MODEL_TYPE:' + 'model_type');
     modelDataNode.addInputSidePacket('READ_DATA_FN:' + 'streaming_reader');
+    modelDataNode.addInputStream(
+      'LORA_MODEL_SPAN:' + LORA_MODEL_REF_INPUT_STREAM,
+    );
+    modelDataNode.addInputStream(
+      'LORA_MODEL_ID:' + LORA_MODEL_ID_TO_LOAD_INPUT_STREAM,
+    );
+    modelDataNode.addOutputStream('LORA_DATA:' + 'lora_model_data');
     graphConfig.addNode(modelDataNode);
 
     // Tokenizer Node
@@ -645,6 +768,7 @@ export class LlmInference extends TaskRunner {
     gpuModelInfo.setEnableFastTuning(true);
     gpuModelInfo.setPreferTextureWeights(true);
     llmGpuOptions.setGpuModelInfo(gpuModelInfo);
+    llmGpuOptions.setLoraRanksList(this.options.getLoraRanksList());
 
     const llmParams = new LlmParameters();
     const transformerParams = new TransformerParameters();
@@ -659,6 +783,7 @@ export class LlmInference extends TaskRunner {
     llmGpuNode.addNodeOptions(llmGpuOptionsProto);
     llmGpuNode.addInputStream('IDS_AND_INPUT_OPTIONS:' + '__stream_0');
     llmGpuNode.addInputStream('FINISH:' + 'finish');
+    llmGpuNode.addInputStream('LORA_DATA:' + 'lora_model_data');
     llmGpuNode.addInputSidePacket('MODEL_DATA:' + '__side_packet_1');
     llmGpuNode.addOutputStream('DECODED_IDS:' + '__stream_3');
     llmGpuNode.addOutputStream('OUTPUT_END:' + '__stream_4');
diff --git a/mediapipe/tasks/web/genai/llm_inference/llm_inference_options.d.ts b/mediapipe/tasks/web/genai/llm_inference/llm_inference_options.d.ts
index d1f69439cf..384370663a 100644
--- a/mediapipe/tasks/web/genai/llm_inference/llm_inference_options.d.ts
+++ b/mediapipe/tasks/web/genai/llm_inference/llm_inference_options.d.ts
@@ -70,4 +70,9 @@ export declare interface LlmInferenceOptions extends TaskRunnerOptions {
    * Random seed for sampling tokens.
    */
   randomSeed?: number;
+
+  /**
+   * The LoRA ranks that will be used during inference.
+   */
+  loraRanks?: number[];
 }
diff --git a/mediapipe/tasks/web/genai/llm_inference/proto/llm_inference_graph_options.proto b/mediapipe/tasks/web/genai/llm_inference/proto/llm_inference_graph_options.proto
index ae8ceb96fd..51ff1b51c9 100644
--- a/mediapipe/tasks/web/genai/llm_inference/proto/llm_inference_graph_options.proto
+++ b/mediapipe/tasks/web/genai/llm_inference/proto/llm_inference_graph_options.proto
@@ -33,4 +33,7 @@ message LlmInferenceGraphOptions {
 
   // Parameters for the sampler, which is used to pick the winning token.
   odml.infra.proto.SamplerParameters sampler_params = 3;
+
+  // The LoRA ranks that will be used during inference.
+  repeated int32 lora_ranks = 4;
 }
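
Usage sketch (not part of the patch): the example below wires together the API surface this change introduces — `loraRanks` at task creation, `loadLoraModel()`, and the `LoraModel` overload of `generateResponse()`. The package import name, WASM CDN path, model file names, and prompt are illustrative placeholders, not values taken from the change.

```ts
import {FilesetResolver, LlmInference, LoraModel} from '@mediapipe/tasks-genai';

async function runWithLora(): Promise<void> {
  // Resolve the GenAI WASM fileset (placeholder CDN path).
  const genaiFileset = await FilesetResolver.forGenAiTasks(
    'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai/wasm',
  );

  // `loraRanks` must list every LoRA rank that will be applied at inference
  // time; it is forwarded to the LLM GPU calculator through the new
  // `lora_ranks` graph option.
  const llmInference = await LlmInference.createFromOptions(genaiFileset, {
    baseOptions: {modelAssetPath: 'gemma-2b-it-gpu-int4.bin'}, // placeholder
    loraRanks: [4],
  });

  // Load the LoRA weights from a URL or a Uint8Array. The returned handle is
  // only valid for the task instance that loaded it.
  const loraModel: LoraModel = await llmInference.loadLoraModel('lora.bin');

  // Apply the LoRA model to a single generation call; the progress listener
  // receives partial results as they stream in.
  const response = await llmInference.generateResponse(
    'Write a haiku about the ocean.',
    loraModel,
    (partialResult, done) => {
      console.log(partialResult, done);
    },
  );
  console.log(response);
}
```

Calling `generateResponse()` without the `loraModel` argument still works as before: an empty packet is sent on the LoRA-ID stream, so the base model is used unchanged.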