diff --git a/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h b/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h index 8b3a8e1468..b4d33cb618 100644 --- a/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h +++ b/mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h @@ -96,6 +96,11 @@ typedef struct { // Optional setting for the number of draft tokens to generate when using // speculative decoding. Setting to 0 will disable speculative decoding. size_t num_draft_tokens; + + // If true, waits for weights to finish uploading when initializing. Otherwise + // initialization may finish before weights have finished uploading, which + // might push some of the weight upload time into input processing. + bool wait_for_weight_uploads; } LlmModelSettings; // LlmSessionConfig configures how to execute the model. diff --git a/mediapipe/tasks/ios/genai/inference/sources/LlmInference.swift b/mediapipe/tasks/ios/genai/inference/sources/LlmInference.swift index 55ac6ddfcf..a2c7aa349e 100644 --- a/mediapipe/tasks/ios/genai/inference/sources/LlmInference.swift +++ b/mediapipe/tasks/ios/genai/inference/sources/LlmInference.swift @@ -66,7 +66,8 @@ import MediaPipeTasksGenAIC supported_lora_ranks: supportedLoraRanks.baseAddress, max_top_k: options.maxTopk, llm_activation_data_type: options.activationDataType.activationDataTypeC, - num_draft_tokens: 0) + num_draft_tokens: 0, + wait_for_weight_uploads: options.waitForWeightUploads) return try LlmTaskRunner(modelSettings: modelSetting) } } @@ -224,6 +225,11 @@ extension LlmInference { /// The activation data type for the model. @objc public var activationDataType: ActivationDataType = .default + /// If true, waits for weights to finish uploading when initializing. Otherwise initialization + /// may finish before weights have finished uploading, which might push some of the weight + /// upload time into input processing. 
+ @objc public var waitForWeightUploads: Bool = false + /// Creates a new instance of `Options` with the given `modelPath` and default values of /// `maxTokens`, `maxTopk`, `supportedLoraRanks` and `activationDataType`. /// This function is only intended to be used from Objective C. diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/llm.cc b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/llm.cc index b8648731e6..18d5ceb2c8 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/llm.cc +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/llm.cc @@ -58,6 +58,7 @@ LlmModelSettings ParseModelSettings(void* bytes, int size) { } output.llm_activation_data_type = kLlmActivationDataTypeDefault; output.num_draft_tokens = 0; + output.wait_for_weight_uploads = false; return output; }