Skip to content

Commit

Permalink
Add vision modality to the C API
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705157449
  • Loading branch information
schmidt-sebastian authored and copybara-github committed Dec 11, 2024
1 parent 45d2f95 commit adc16ae
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 1 deletion.
18 changes: 18 additions & 0 deletions mediapipe/tasks/cc/genai/inference/c/llm_inference_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ typedef struct {
// Path to the model artifact.
const char* model_path;

// Path to the vision encoder to use for vision modality. Optional.
const char* vision_encoder_path;

// Path to the vision adapter to use for vision modality. Optional.
const char* vision_adapter_path;

// Directory path for storing model related tokenizer and cache weights. The
// user is responsible for providing a directory that is writable by the
// program.
Expand Down Expand Up @@ -121,6 +127,13 @@ typedef struct {
// Path to the LoRA tflite flatbuffer file. Optional.
// This is only compatible with GPU models.
const char* lora_path;

// Whether to configure the graph to include the token cost calculator,
// which allows users to only compute the cost of a prompt.
bool include_token_cost_calculator;

// Whether to configure the graph to include the vision modality.
bool enable_vision_modality;
} LlmSessionConfig;

// LlmResponseContext is the return type for
Expand Down Expand Up @@ -166,6 +179,11 @@ ODML_EXPORT void LlmInferenceEngine_Session_Delete(
ODML_EXPORT int LlmInferenceEngine_Session_AddQueryChunk(
LlmInferenceEngine_Session* session, const char* input, char** error_msg);

// Adds an image to the session for use with the vision modality.
//
// `sk_bitmap` is passed as an untyped pointer; presumably it points to a Skia
// SkBitmap (TODO confirm the expected concrete type and lifetime).
// Returns an integer status code; presumably 0 on success, mirroring
// LlmInferenceEngine_Session_AddQueryChunk (TODO confirm). On failure, a
// description of the error is reported through `error_msg`.
ODML_EXPORT int LlmInferenceEngine_Session_AddImage(
LlmInferenceEngine_Session* session, const void* sk_bitmap,
char** error_msg);

// Return the generated output based on the previously added query chunks in
// sync mode.
ODML_EXPORT LlmResponseContext
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,13 @@ int LlmInferenceEngine_Session_AddQueryChunk(
return 0;
}

// Adds an image to the session. Image input is not implemented in this
// backend yet, so the call always fails.
//
// `session` and `sk_bitmap` are currently ignored. When `error_msg` is
// non-null, `*error_msg` is set to a strdup()-allocated copy of
// "Not implemented" (presumably the caller is expected to free() it — TODO
// confirm against the other entry points of this API).
//
// Always returns 12 (presumably absl::StatusCode::kUnimplemented — TODO
// confirm and replace the literal with the named constant).
ODML_EXPORT int LlmInferenceEngine_Session_AddImage(
    LlmInferenceEngine_Session* session, const void* sk_bitmap,
    char** error_msg) {
  // Guard the out-parameter so that a caller which is not interested in the
  // message (and passes a null pointer) does not trigger a null dereference.
  if (error_msg) {
    *error_msg = strdup("Not implemented");
  }
  return 12;
}

LlmResponseContext LlmInferenceEngine_Session_PredictSync(
LlmInferenceEngine_Session* session) {
LlmInferenceEngine_Session_PredictAsync(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ extension LlmInference {
topp: options.topp,
temperature: options.temperature,
random_seed: options.randomSeed,
lora_path: nil)
lora_path: nil,
include_token_cost_calculator: true,
enable_vision_modality: false)

/// If `loraPath` is != nil, modify session config with the corresponding C string and invoke
/// the method to create session runner within the scope where the C String of the `loraPath`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ import MediaPipeTasksGenAIC
try options.supportedLoraRanks.withUnsafeMutableBufferPointer { supportedLoraRanks in
let modelSetting = LlmModelSettings(
model_path: modelPath,
vision_encoder_path: nil,
vision_adapter_path: nil,
cache_dir: cacheDirectory,
max_num_tokens: options.maxTokens,
num_decode_steps_per_sync: numberOfDecodeStepsPerSync,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ LlmModelSettings ParseModelSettings(void* bytes, int size) {

LlmModelSettings output;
output.model_path = strdup(input.model_path().c_str());
output.vision_encoder_path = nullptr;
output.vision_adapter_path = nullptr;
output.cache_dir = strdup(input.cache_dir().c_str());
output.sequence_batch_size = input.sequence_batch_size();
output.num_decode_steps_per_sync = input.num_decode_steps_per_sync();
Expand Down Expand Up @@ -74,6 +76,8 @@ LlmSessionConfig ParseSessionConfig(void* bytes, int size) {
if (input.has_lora_path()) {
output.lora_path = strdup(input.lora_path().c_str());
}
output.include_token_cost_calculator = true;
output.enable_vision_modality = false;
return output;
}

Expand Down

0 comments on commit adc16ae

Please sign in to comment.