Expose Web LoRA API.
PiperOrigin-RevId: 631170793
MediaPipe Team authored and copybara-github committed May 6, 2024
1 parent c930b9b commit fccb895
Showing 4 changed files with 147 additions and 10 deletions.
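
Reader's note (not part of the commit): end to end, the new surface is loadLoraModel() plus a LoraModel-aware generateResponse() overload. A minimal usage sketch, assuming the published @mediapipe/tasks-genai package — all URLs and asset paths below are placeholders:

import {FilesetResolver, LlmInference} from '@mediapipe/tasks-genai';

const genai = await FilesetResolver.forGenAiTasks(
    'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai/wasm');
const llmInference = await LlmInference.createFromOptions(genai, {
  baseOptions: {modelAssetPath: '/assets/gemma-2b-it-gpu-int4.bin'},
  // Ranks of the LoRA adapters you plan to load, declared up front so the
  // GPU graph is built with matching adapter slots (see the new loraRanks
  // option below).
  loraRanks: [4],
});
const lora = await llmInference.loadLoraModel('/assets/my-lora.bin');
const adapted = await llmInference.generateResponse('Hello', lora);
// Omitting the LoRA handle runs the unmodified base model.
const base = await llmInference.generateResponse('Hello');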
8 changes: 6 additions & 2 deletions mediapipe/tasks/web/genai/index.ts
@@ -15,11 +15,15 @@
*/

import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver';
import {LlmInference as LlmInferenceImpl} from '../../../tasks/web/genai/llm_inference/llm_inference';
import {
LlmInference as LlmInferenceImpl,
LoraModel as LoraModelImpl,
} from '../../../tasks/web/genai/llm_inference/llm_inference';

// Declare the variables locally so that Rollup in OSS includes them explicitly
// as exports.
const FilesetResolver = FilesetResolverImpl;
const LlmInference = LlmInferenceImpl;
const LoraModel = LoraModelImpl;

export {FilesetResolver, LlmInference};
export {FilesetResolver, LlmInference, LoraModel};
141 changes: 133 additions & 8 deletions mediapipe/tasks/web/genai/llm_inference/llm_inference.ts
@@ -74,6 +74,10 @@ const OUTPUT_END_STREAM = 'text_end';
const TOKEN_COST_INPUT_STREAM = 'token_cost_in';
const TOKEN_COST_OUTPUT_STREAM = 'token_cost_out';

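// Streams (new in this change) that carry LoRA state into the graph: the
// per-query id of the LoRA to apply, the raw LoRA weights, and the id under
// which those weights should be registered.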
const LORA_MODEL_ID_TO_APPLY_INPUT_STREAM = 'lora_model_id_to_apply_in';
const LORA_MODEL_REF_INPUT_STREAM = 'lora_model_ref_in';
const LORA_MODEL_ID_TO_LOAD_INPUT_STREAM = 'lora_model_id_to_load_in';

const DEFAULT_MAX_TOKENS = 512;
const DEFAULT_TOP_K = 1;
const DEFAULT_TOP_P = 1.0;
@@ -91,6 +95,18 @@ const MAX_BUFFER_SIZE_FOR_LLM = 524550144;
// Amount of the max WebGPU buffer binding size required for LLM models.
const MAX_STORAGE_BUFFER_BINDING_SIZE_FOR_LLM = 524550144;

/**
* The LoRA model to be used by `generateResponse()` of an LLM Inference task.
*/
export class LoraModel {
private static nextLoraModelId = 0;
readonly loraModelId: number;
constructor(readonly owner: LlmInference) {
this.loraModelId = LoraModel.nextLoraModelId;
LoraModel.nextLoraModelId++;
}
}

/**
* Performs LLM Inference on text.
*/
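
Reader's note: the handle class above is deliberately thin — ids come from a static counter, and the owner back-reference lets generateResponse() reject handles minted by a different task instance. A sketch of the resulting behavior (illustrative, not from the diff; paths are placeholders):

// Each load mints a fresh, process-wide id.
const loraA = await llmInference.loadLoraModel('/assets/lora-a.bin'); // id 0
const loraB = await llmInference.loadLoraModel('/assets/lora-b.bin'); // id 1
// Passing a handle to a different LlmInference instance throws:
// 'The LoRA model was not loaded by this LLM Inference task.'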
@@ -300,6 +316,9 @@ export class LlmInference extends TaskRunner {
if (options.randomSeed) {
this.samplerParams.setSeed(options.randomSeed);
}
if ('loraRanks' in options) {
this.options.setLoraRanksList(options.loraRanks ?? []);
}

let onFinishedLoadingData!: () => void;
const finishedLoadingDataPromise = new Promise<void>((resolve, reject) => {
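
Reader's note: the `'loraRanks' in options` guard rewrites the list only when the caller mentions the key at all, so a partial options update leaves it untouched. A hedged sketch, assuming the llmInference instance from above and the task's exported setOptions():

// Leaves previously configured LoRA ranks intact.
await llmInference.setOptions({topK: 5});
// Explicitly clears them (key present, value undefined -> []).
await llmInference.setOptions({loraRanks: undefined});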
@@ -398,24 +417,69 @@
text: string,
progressListener: ProgressListener,
): Promise<string>;
/**
* Performs LLM Inference on the provided text and waits
* asynchronously for the response. Only one call to `generateResponse()` can
* run at a time.
*
* @export
* @param text The text to process.
* @param loraModel The LoRA model to apply on the text generation.
* @return The generated text result.
*/
generateResponse(text: string, loraModel: LoraModel): Promise<string>;
/**
* Performs LLM Inference on the provided text and waits
* asynchronously for the response. Only one call to `generateResponse()` can
* run at a time.
*
* @export
* @param text The text to process.
* @param loraModel The LoRA model to apply on the text generation.
* @param progressListener A listener that will be triggered when the task has
* a new partial response generated.
* @return The generated text result.
*/
generateResponse(
text: string,
loraModel: LoraModel,
progressListener: ProgressListener,
): Promise<string>;
/** @export */
generateResponse(
text: string,
loraModelOrProgressListener?: ProgressListener | LoraModel,
progressListener?: ProgressListener,
): Promise<string> {
if (this.isProcessing) {
throw new Error('Previous invocation or loading is still ongoing.');
}
if (progressListener) {
this.userProgressListener = progressListener;
}
this.userProgressListener =
typeof loraModelOrProgressListener === 'function'
? loraModelOrProgressListener
: progressListener;
this.generationResult.length = 0;
this.isProcessing = true;
this.graphRunner.addStringToStream(
text,
INPUT_STREAM,
this.getSynctheticTimestamp(),
);
const timeStamp = this.getSynctheticTimestamp();
this.graphRunner.addStringToStream(text, INPUT_STREAM, timeStamp);
if (loraModelOrProgressListener instanceof LoraModel) {
if (loraModelOrProgressListener.owner !== this) {
this.isProcessing = false;
throw new Error(
'The LoRA model was not loaded by this LLM Inference task.',
);
}
this.graphRunner.addUintToStream(
loraModelOrProgressListener.loraModelId,
LORA_MODEL_ID_TO_APPLY_INPUT_STREAM,
timeStamp,
);
} else {
this.graphRunner.addEmptyPacketToStream(
LORA_MODEL_ID_TO_APPLY_INPUT_STREAM,
timeStamp,
);
}
this.finishProcessing();
return new Promise<string>((resolve, reject) => {
this.resolveGeneration = resolve;
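
With the union-typed second parameter, the single implementation above serves all four call shapes. Hedged examples (prompt text and listener bodies are placeholders; `lora` is a LoraModel returned by loadLoraModel()):

// Base model only.
await llmInference.generateResponse('prompt');
// Base model, streaming partial results.
await llmInference.generateResponse('prompt', (partial, done) => {
  console.log(partial, done);
});
// LoRA-adapted generation.
await llmInference.generateResponse('prompt', lora);
// LoRA-adapted generation with streaming.
await llmInference.generateResponse('prompt', lora, (partial, done) => {
  console.log(partial, done);
});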
@@ -448,6 +512,52 @@ export class LlmInference extends TaskRunner {
return this.latestTokenCostQueryResult;
}

/**
* Loads a LoRA model into the LLM Inference task so that it can be used by
* `generateResponse()`. The returned LoRA model can be applied only to the
* current LLM Inference task.
*
* @export
* @param modelAsset The URL of the model, or the model content as a
* Uint8Array.
* @return A loaded LoRA model.
*/
async loadLoraModel(
modelAsset: string | Uint8Array,
): Promise<LoraModel> {
if (this.isProcessing) {
throw new Error('Cannot load LoRA model while loading or processing.');
}
this.isProcessing = true;
const wasmFileReference =
modelAsset instanceof Uint8Array
? WasmFileReference.loadFromArray(
this.graphRunner.wasmModule,
modelAsset,
)
: await WasmFileReference.loadFromUrl(
this.graphRunner.wasmModule,
modelAsset,
);
const loraModel = new LoraModel(this);
(
this.graphRunner as unknown as LlmGraphRunner
).addWasmFileReferenceToStream(
wasmFileReference,
LORA_MODEL_REF_INPUT_STREAM,
this.getSynctheticTimestamp(),
);
this.graphRunner.addUintToStream(
loraModel.loraModelId,
LORA_MODEL_ID_TO_LOAD_INPUT_STREAM,
this.getSynctheticTimestamp(),
);
this.finishProcessing();
wasmFileReference.free();
this.isProcessing = false;
return loraModel;
}

/**
* Decodes the response from the LLM engine and returns a human-readable
* string.
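
Reader's note: loadLoraModel() reuses isProcessing as a mutex so loads cannot overlap with generation, and it frees the wasm-side file buffer as soon as the weights have been pushed into the graph. Both asset forms are accepted — a sketch with placeholder paths:

// From a URL (fetched via WasmFileReference.loadFromUrl):
const fromUrl = await llmInference.loadLoraModel('/assets/lora.bin');
// From bytes already in memory:
const resp = await fetch('/assets/lora.bin');
const fromBytes = await llmInference.loadLoraModel(
    new Uint8Array(await resp.arrayBuffer()));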
@@ -568,6 +678,9 @@ export class LlmInference extends TaskRunner {
const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(INPUT_STREAM);
graphConfig.addInputStream(TOKEN_COST_INPUT_STREAM);
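// Register the three LoRA streams declared at the top of the file.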
graphConfig.addInputStream(LORA_MODEL_ID_TO_APPLY_INPUT_STREAM);
graphConfig.addInputStream(LORA_MODEL_REF_INPUT_STREAM);
graphConfig.addInputStream(LORA_MODEL_ID_TO_LOAD_INPUT_STREAM);
graphConfig.addInputSidePacket('streaming_reader');
graphConfig.addOutputStream(OUTPUT_STREAM);
graphConfig.addOutputStream(OUTPUT_END_STREAM);
@@ -577,6 +690,9 @@
const tokenizerInputBuildNode = new CalculatorGraphConfig.Node();
tokenizerInputBuildNode.setCalculator('TokenizerInputBuildCalculator');
tokenizerInputBuildNode.addInputStream('PROMPT:' + INPUT_STREAM);
tokenizerInputBuildNode.addInputStream(
'LORA_ID:' + LORA_MODEL_ID_TO_APPLY_INPUT_STREAM,
);
tokenizerInputBuildNode.addOutputStream('prompt');
graphConfig.addNode(tokenizerInputBuildNode);

@@ -586,6 +702,13 @@
modelDataNode.addOutputSidePacket('MODEL_DATA:' + '__side_packet_1');
modelDataNode.addOutputSidePacket('MODEL_TYPE:' + 'model_type');
modelDataNode.addInputSidePacket('READ_DATA_FN:' + 'streaming_reader');
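// LoRA weights enter through the model-data node, which re-emits them as
// LORA_DATA for the GPU inference node further down.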
modelDataNode.addInputStream(
'LORA_MODEL_SPAN:' + LORA_MODEL_REF_INPUT_STREAM,
);
modelDataNode.addInputStream(
'LORA_MODEL_ID:' + LORA_MODEL_ID_TO_LOAD_INPUT_STREAM,
);
modelDataNode.addOutputStream('LORA_DATA:' + 'lora_model_data');
graphConfig.addNode(modelDataNode);

// Tokenizer Node
@@ -645,6 +768,7 @@
gpuModelInfo.setEnableFastTuning(true);
gpuModelInfo.setPreferTextureWeights(true);
llmGpuOptions.setGpuModelInfo(gpuModelInfo);
llmGpuOptions.setLoraRanksList(this.options.getLoraRanksList());

const llmParams = new LlmParameters();
const transformerParams = new TransformerParameters();
@@ -659,6 +783,7 @@
llmGpuNode.addNodeOptions(llmGpuOptionsProto);
llmGpuNode.addInputStream('IDS_AND_INPUT_OPTIONS:' + '__stream_0');
llmGpuNode.addInputStream('FINISH:' + 'finish');
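// Consumes the LORA_DATA stream emitted by ModelDataCalculator above.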
llmGpuNode.addInputStream('LORA_DATA:' + 'lora_model_data');
llmGpuNode.addInputSidePacket('MODEL_DATA:' + '__side_packet_1');
llmGpuNode.addOutputStream('DECODED_IDS:' + '__stream_3');
llmGpuNode.addOutputStream('OUTPUT_END:' + '__stream_4');
5 changes: 5 additions & 0 deletions mediapipe/tasks/web/genai/llm_inference/llm_inference_options.d.ts
@@ -70,4 +70,9 @@ export declare interface LlmInferenceOptions extends TaskRunnerOptions {
* Random seed for sampling tokens.
*/
randomSeed?: number;

/**
* The LoRA ranks that will be used during inference.
*/
loraRanks?: number[];
}
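
Reader's note: since the declared ranks are copied into the GPU calculator options when the graph is built (see setLoraRanksList above), a LoRA file loaded later presumably must use one of these ranks. Illustrative values only:

const options: LlmInferenceOptions = {
  baseOptions: {modelAssetPath: '/assets/model.bin'},  // placeholder path
  loraRanks: [4, 8],  // allow rank-4 and rank-8 adapters
};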
3 changes: 3 additions & 0 deletions mediapipe/tasks/cc/genai/inference/proto/llm_inference_graph_options.proto
@@ -33,4 +33,7 @@ message LlmInferenceGraphOptions {

// Parameters for the sampler, which is used to pick the winning token.
odml.infra.proto.SamplerParameters sampler_params = 3;

// The LoRA ranks that will be used during inference.
repeated int32 lora_ranks = 4;
}
