Skip to content

Commit

Permalink
Add wait for weight upload flag.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 682071463
  • Loading branch information
MediaPipe Team authored and copybara-github committed Oct 3, 2024
1 parent cf4857f commit dab7439
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 1 deletion.
5 changes: 5 additions & 0 deletions mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ typedef struct {
// Optional setting for the number of draft tokens to generate when using
// speculative decoding. Setting to 0 will disable speculative decoding.
size_t num_draft_tokens;

// If true, waits for weights to finish uploading when initializing. Otherwise
// initialization may finish before weights have finished uploading which
// might push some of the weight upload time into input processing.
bool wait_for_weight_uploads;
} LlmModelSettings;

// LlmSessionConfig configures how to execute the model.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ import MediaPipeTasksGenAIC
supported_lora_ranks: supportedLoraRanks.baseAddress,
max_top_k: options.maxTopk,
llm_activation_data_type: options.activationDataType.activationDataTypeC,
num_draft_tokens: 0)
num_draft_tokens: 0,
wait_for_weight_uploads: options.waitForWeightUploads)
return try LlmTaskRunner(modelSettings: modelSetting)
}
}
Expand Down Expand Up @@ -224,6 +225,11 @@ extension LlmInference {
/// The activation data type for the model.
@objc public var activationDataType: ActivationDataType = .default

/// If true, waits for weights to finish uploading when initializing. Otherwise initialization
/// may finish before weights have finished uploading which might push some of the weight upload
/// time into input processing.
@objc public var waitForWeightUploads: Bool = false

/// Creates a new instance of `Options` with the given `modelPath` and default values of
/// `maxTokens`, `maxTopk`, `supportedLoraRanks` and `activationDataType`.
/// This function is only intended to be used from Objective C.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ LlmModelSettings ParseModelSettings(void* bytes, int size) {
}
output.llm_activation_data_type = kLlmActivationDataTypeDefault;
output.num_draft_tokens = 0;
output.wait_for_weight_uploads = false;
return output;
}

Expand Down

0 comments on commit dab7439

Please sign in to comment.